1593 lines
82 KiB
Python
1593 lines
82 KiB
Python
# coding=utf-8
|
|
# Copyright 2022 The HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Generation configuration class and utilities."""
|
|
|
|
import copy
|
|
import json
|
|
import os
|
|
import warnings
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass, is_dataclass
|
|
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
|
|
|
from .. import __version__
|
|
from ..configuration_utils import PretrainedConfig
|
|
from ..utils import (
|
|
GENERATION_CONFIG_NAME,
|
|
ExplicitEnum,
|
|
PushToHubMixin,
|
|
cached_file,
|
|
download_url,
|
|
extract_commit_hash,
|
|
is_remote_url,
|
|
is_torch_available,
|
|
logging,
|
|
)
|
|
from ..utils.deprecation import deprecate_kwarg
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from ..modeling_utils import PreTrainedModel
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version")
|
|
NEED_SETUP_CACHE_CLASSES_MAPPING = {}
|
|
QUANT_BACKEND_CLASSES_MAPPING = {}
|
|
ALL_CACHE_IMPLEMENTATIONS = []
|
|
|
|
if is_torch_available():
|
|
from ..cache_utils import (
|
|
HQQQuantizedCache,
|
|
HybridCache,
|
|
HybridChunkedCache,
|
|
OffloadedHybridCache,
|
|
OffloadedStaticCache,
|
|
QuantoQuantizedCache,
|
|
SlidingWindowCache,
|
|
StaticCache,
|
|
)
|
|
from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor
|
|
|
|
NEED_SETUP_CACHE_CLASSES_MAPPING = {
|
|
"static": StaticCache,
|
|
"offloaded_static": OffloadedStaticCache,
|
|
"sliding_window": SlidingWindowCache,
|
|
"hybrid": HybridCache,
|
|
"hybrid_chunked": HybridChunkedCache,
|
|
"offloaded_hybrid": OffloadedHybridCache,
|
|
"offloaded_hybrid_chunked": OffloadedHybridCache,
|
|
}
|
|
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
|
|
ALL_CACHE_IMPLEMENTATIONS = list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + ["offloaded", "dynamic", "quantized"]
|
|
|
|
|
|
class GenerationMode(ExplicitEnum):
|
|
"""
|
|
Possible generation modes, downstream of the [`~generation.GenerationMixin.generate`] method.
|
|
"""
|
|
|
|
# Non-beam methods
|
|
CONTRASTIVE_SEARCH = "contrastive_search"
|
|
GREEDY_SEARCH = "greedy_search"
|
|
SAMPLE = "sample"
|
|
ASSISTED_GENERATION = "assisted_generation"
|
|
DOLA_GENERATION = "dola_generation"
|
|
# Beam methods
|
|
BEAM_SEARCH = "beam_search"
|
|
BEAM_SAMPLE = "beam_sample"
|
|
CONSTRAINED_BEAM_SEARCH = "constrained_beam_search"
|
|
GROUP_BEAM_SEARCH = "group_beam_search"
|
|
|
|
|
|
class GenerationConfig(PushToHubMixin):
|
|
# no-format
|
|
"""
|
|
Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
|
|
for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
|
|
|
|
- *greedy decoding* if `num_beams=1` and `do_sample=False`
|
|
- *contrastive search* if `penalty_alpha>0.` and `top_k>1`
|
|
- *multinomial sampling* if `num_beams=1` and `do_sample=True`
|
|
- *beam-search decoding* if `num_beams>1` and `do_sample=False`
|
|
- *beam-search multinomial sampling* if `num_beams>1` and `do_sample=True`
|
|
- *diverse beam-search decoding* if `num_beams>1` and `num_beam_groups>1`
|
|
- *constrained beam-search decoding* if `constraints!=None` or `force_words_ids!=None`
|
|
- *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
|
|
- *dola decoding* if `dola_layers` is passed to `.generate()`
|
|
|
|
To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
|
|
|
|
<Tip>
|
|
|
|
A large number of these flags control the logits or the stopping criteria of the generation. Make sure you check
|
|
the [generate-related classes](https://huggingface.co/docs/transformers/internal/generation_utils) for a full
|
|
description of the possible manipulations, as well as examples of their usage.
|
|
|
|
</Tip>
|
|
|
|
Arg:
|
|
> Parameters that control the length of the output
|
|
|
|
max_length (`int`, *optional*, defaults to 20):
|
|
The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
|
|
`max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set.
|
|
max_new_tokens (`int`, *optional*):
|
|
The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
|
|
min_length (`int`, *optional*, defaults to 0):
|
|
The minimum length of the sequence to be generated. Corresponds to the length of the input prompt +
|
|
`min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.
|
|
min_new_tokens (`int`, *optional*):
|
|
The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.
|
|
early_stopping (`bool` or `str`, *optional*, defaults to `False`):
|
|
Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
|
|
`True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
|
|
heuristic is applied and the generation stops when is it very unlikely to find better candidates;
|
|
`"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
|
|
beam search algorithm).
|
|
max_time (`float`, *optional*):
|
|
The maximum amount of time you allow the computation to run for in seconds. generation will still finish
|
|
the current pass after allocated time has been passed.
|
|
stop_strings (`str or list[str]`, *optional*):
|
|
A string or a list of strings that should terminate generation if the model outputs them.
|
|
|
|
> Parameters that control the generation strategy used
|
|
|
|
do_sample (`bool`, *optional*, defaults to `False`):
|
|
Whether or not to use sampling ; use greedy decoding otherwise.
|
|
num_beams (`int`, *optional*, defaults to 1):
|
|
Number of beams for beam search. 1 means no beam search.
|
|
num_beam_groups (`int`, *optional*, defaults to 1):
|
|
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
|
|
[this paper](https://huggingface.co/papers/1610.02424) for more details.
|
|
penalty_alpha (`float`, *optional*):
|
|
The values balance the model confidence and the degeneration penalty in contrastive search decoding.
|
|
dola_layers (`str` or `list[int]`, *optional*):
|
|
The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must
|
|
be one of "low" or "high", which means using the lower part or higher part of the model layers, respectively.
|
|
"low" means the first half of the layers up to the first 20 layers, and "high" means the last half of the
|
|
layers up to the last 20 layers.
|
|
If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa.
|
|
The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks,
|
|
`'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md)
|
|
or [the paper](https://huggingface.co/papers/2309.03883) for more details.
|
|
|
|
> Parameters that control the cache
|
|
|
|
use_cache (`bool`, *optional*, defaults to `True`):
|
|
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
|
|
speed up decoding.
|
|
cache_implementation (`str`, *optional*, default to `None`):
|
|
Name of the cache class that will be instantiated in `generate`, for faster decoding. Possible values are:
|
|
|
|
- `"dynamic"`: [`DynamicCache`]
|
|
- `"static"`: [`StaticCache`]
|
|
- `"offloaded_static"`: [`OffloadedStaticCache`]
|
|
- `"sliding_window"`: [`SlidingWindowCache`]
|
|
- `"hybrid"`: [`HybridCache`]
|
|
- `"quantized"`: [`QuantizedCache`]
|
|
|
|
If none is specified, we will use the default cache for the model (which is often [`DynamicCache`]). See
|
|
our [cache documentation](https://huggingface.co/docs/transformers/en/kv_cache) for further information.
|
|
cache_config (`dict`, *optional*, default to `None`):
|
|
Arguments used in the key-value cache class can be passed in `cache_config`.
|
|
return_legacy_cache (`bool`, *optional*, default to `True`):
|
|
Whether to return the legacy or new format of the cache when `DynamicCache` is used by default.
|
|
|
|
> Parameters for manipulation of the model output logits
|
|
|
|
temperature (`float`, *optional*, defaults to 1.0):
|
|
The value used to module the next token probabilities. This value is set in a model's `generation_config.json` file. If it isn't set, the default value is 1.0
|
|
top_k (`int`, *optional*, defaults to 50):
|
|
The number of highest probability vocabulary tokens to keep for top-k-filtering. This value is set in a model's `generation_config.json` file. If it isn't set, the default value is 50.
|
|
top_p (`float`, *optional*, defaults to 1.0):
|
|
If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to
|
|
`top_p` or higher are kept for generation. This value is set in a model's `generation_config.json` file. If it isn't set, the default value is 1.0
|
|
min_p (`float`, *optional*):
|
|
Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
|
|
value between 0 and 1. Typical values are in the 0.01-0.2 range, comparably selective as setting `top_p` in
|
|
the 0.99-0.8 range (use the opposite of normal `top_p` values).
|
|
typical_p (`float`, *optional*, defaults to 1.0):
|
|
Local typicality measures how similar the conditional probability of predicting a target token next is to
|
|
the expected conditional probability of predicting a random token next, given the partial text already
|
|
generated. If set to float < 1, the smallest set of the most locally typical tokens with probabilities that
|
|
add up to `typical_p` or higher are kept for generation. See [this
|
|
paper](https://huggingface.co/papers/2202.00666) for more details.
|
|
epsilon_cutoff (`float`, *optional*, defaults to 0.0):
|
|
If set to float strictly between 0 and 1, only tokens with a conditional probability greater than
|
|
`epsilon_cutoff` will be sampled. In the paper, suggested values range from 3e-4 to 9e-4, depending on the
|
|
size of the model. See [Truncation Sampling as Language Model
|
|
Desmoothing](https://huggingface.co/papers/2210.15191) for more details.
|
|
eta_cutoff (`float`, *optional*, defaults to 0.0):
|
|
Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly between
|
|
0 and 1, a token is only considered if it is greater than either `eta_cutoff` or `sqrt(eta_cutoff) *
|
|
exp(-entropy(softmax(next_token_logits)))`. The latter term is intuitively the expected next token
|
|
probability, scaled by `sqrt(eta_cutoff)`. In the paper, suggested values range from 3e-4 to 2e-3,
|
|
depending on the size of the model. See [Truncation Sampling as Language Model
|
|
Desmoothing](https://huggingface.co/papers/2210.15191) for more details.
|
|
diversity_penalty (`float`, *optional*, defaults to 0.0):
|
|
This value is subtracted from a beam's score if it generates a token same as any beam from other group at a
|
|
particular time. Note that `diversity_penalty` is only effective if `group beam search` is enabled.
|
|
repetition_penalty (`float`, *optional*, defaults to 1.0):
|
|
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
|
paper](https://huggingface.co/papers/1909.05858) for more details.
|
|
encoder_repetition_penalty (`float`, *optional*, defaults to 1.0):
|
|
The parameter for encoder_repetition_penalty. An exponential penalty on sequences that are not in the
|
|
original input. 1.0 means no penalty.
|
|
length_penalty (`float`, *optional*, defaults to 1.0):
|
|
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
|
|
the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
|
|
likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
|
|
`length_penalty` < 0.0 encourages shorter sequences.
|
|
no_repeat_ngram_size (`int`, *optional*, defaults to 0):
|
|
If set to int > 0, all ngrams of that size can only occur once.
|
|
bad_words_ids (`list[list[int]]`, *optional*):
|
|
List of list of token ids that are not allowed to be generated. Check
|
|
[`~generation.NoBadWordsLogitsProcessor`] for further documentation and examples.
|
|
force_words_ids (`list[list[int]]` or `list[list[list[int]]]`, *optional*):
|
|
List of token ids that must be generated. If given a `list[list[int]]`, this is treated as a simple list of
|
|
words that must be included, the opposite to `bad_words_ids`. If given `list[list[list[int]]]`, this
|
|
triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one
|
|
can allow different forms of each word.
|
|
renormalize_logits (`bool`, *optional*, defaults to `False`):
|
|
Whether to renormalize the logits after applying all the logits processors (including the custom
|
|
ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits
|
|
are normalized but some logit processors break the normalization.
|
|
constraints (`list[Constraint]`, *optional*):
|
|
Custom constraints that can be added to the generation to ensure that the output will contain the use of
|
|
certain tokens as defined by `Constraint` objects, in the most sensible way possible.
|
|
forced_bos_token_id (`int`, *optional*, defaults to `model.config.forced_bos_token_id`):
|
|
The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for
|
|
multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target
|
|
language token.
|
|
forced_eos_token_id (`int` or list[int]`, *optional*, defaults to `model.config.forced_eos_token_id`):
|
|
The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
|
|
list to set multiple *end-of-sequence* tokens.
|
|
remove_invalid_values (`bool`, *optional*, defaults to `model.config.remove_invalid_values`):
|
|
Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to crash.
|
|
Note that using `remove_invalid_values` can slow down generation.
|
|
exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
|
|
This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been
|
|
generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where
|
|
penalty starts and `decay_factor` represents the factor of exponential decay
|
|
suppress_tokens (`list[int]`, *optional*):
|
|
A list of tokens that will be suppressed at generation. The `SupressTokens` logit processor will set their
|
|
log probs to `-inf` so that they are not sampled.
|
|
begin_suppress_tokens (`list[int]`, *optional*):
|
|
A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens` logit
|
|
processor will set their log probs to `-inf` so that they are not sampled.
|
|
sequence_bias (`dict[tuple[int], float]`, *optional*)):
|
|
Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
|
|
sequence being selected, while negative biases do the opposite. Check
|
|
[`~generation.SequenceBiasLogitsProcessor`] for further documentation and examples.
|
|
token_healing (`bool`, *optional*, defaults to `False`):
|
|
Heal tail tokens of prompts by replacing them with their appropriate extensions.
|
|
This enhances the quality of completions for prompts affected by greedy tokenization bias.
|
|
guidance_scale (`float`, *optional*):
|
|
The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
|
|
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
|
|
prompt, usually at the expense of poorer quality.
|
|
low_memory (`bool`, *optional*):
|
|
Switch to sequential beam search and sequential topk for contrastive search to reduce peak memory.
|
|
Used with beam search and contrastive search.
|
|
watermarking_config (`BaseWatermarkingConfig` or `dict`, *optional*):
|
|
Arguments used to watermark the model outputs by adding a small bias to randomly selected set of "green"
|
|
tokens. See the docs of [`SynthIDTextWatermarkingConfig`] and [`WatermarkingConfig`] for more
|
|
details. If passed as `Dict`, it will be converted to a `WatermarkingConfig` internally.
|
|
|
|
> Parameters that define the output variables of generate
|
|
|
|
num_return_sequences (`int`, *optional*, defaults to 1):
|
|
The number of independently computed returned sequences for each element in the batch.
|
|
output_attentions (`bool`, *optional*, defaults to `False`):
|
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
|
tensors for more details.
|
|
output_hidden_states (`bool`, *optional*, defaults to `False`):
|
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
|
more details.
|
|
output_scores (`bool`, *optional*, defaults to `False`):
|
|
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
|
|
output_logits (`bool`, *optional*):
|
|
Whether or not to return the unprocessed prediction logit scores. See `logits` under returned tensors for
|
|
more details.
|
|
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
|
|
Whether or not to return a [`~utils.ModelOutput`], as opposed to returning exclusively the generated
|
|
sequence. This flag must be set to `True` to return the generation cache (when `use_cache` is `True`)
|
|
or optional outputs (see flags starting with `output_`)
|
|
|
|
> Special tokens that can be used at generation time
|
|
|
|
pad_token_id (`int`, *optional*):
|
|
The id of the *padding* token.
|
|
bos_token_id (`int`, *optional*):
|
|
The id of the *beginning-of-sequence* token.
|
|
eos_token_id (`Union[int, list[int]]`, *optional*):
|
|
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
|
|
|
> Generation parameters exclusive to encoder-decoder models
|
|
|
|
encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0):
|
|
If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
|
|
`decoder_input_ids`.
|
|
decoder_start_token_id (`int` or `list[int]`, *optional*):
|
|
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token or a list of length
|
|
`batch_size`. Indicating a list enables different start ids for each element in the batch
|
|
(e.g. multilingual models with different target languages in one batch)
|
|
|
|
> Generation parameters exclusive to assistant generation
|
|
is_assistant (`bool`, *optional*, defaults to `False`):
|
|
Whether the model is an assistant (draft) model.
|
|
num_assistant_tokens (`int`, *optional*, defaults to 20):
|
|
Defines the number of _speculative tokens_ that shall be generated by the assistant model before being
|
|
checked by the target model at each iteration. Higher values for `num_assistant_tokens` make the generation
|
|
more _speculative_ : If the assistant model is performant larger speed-ups can be reached, if the assistant
|
|
model requires lots of corrections, lower speed-ups are reached.
|
|
num_assistant_tokens_schedule (`str`, *optional*, defaults to `"constant"`):
|
|
Defines the schedule at which max assistant tokens shall be changed during inference.
|
|
- `"heuristic"`: When all speculative tokens are correct, increase `num_assistant_tokens` by 2 else
|
|
reduce by 1. `num_assistant_tokens` value is persistent over multiple generation calls with the same assistant model.
|
|
- `"heuristic_transient"`: Same as `"heuristic"` but `num_assistant_tokens` is reset to its initial value after each generation call.
|
|
- `"constant"`: `num_assistant_tokens` stays unchanged during generation
|
|
assistant_confidence_threshold (`float`, *optional*, defaults to 0.4):
|
|
The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current token is lower
|
|
than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_
|
|
(defined by `num_assistant_tokens`) is not yet reached. The assistant's confidence threshold is adjusted throughout the speculative iterations to reduce the number of unnecessary draft and target forward passes, biased towards avoiding false negatives.
|
|
`assistant_confidence_threshold` value is persistent over multiple generation calls with the same assistant model.
|
|
It is an unsupervised version of the dynamic speculation lookahead
|
|
from Dynamic Speculation Lookahead Accelerates Speculative Decoding of Large Language Models <https://huggingface.co/papers/2405.04304>.
|
|
prompt_lookup_num_tokens (`int`, *optional*):
|
|
The number of tokens to be output as candidate tokens.
|
|
max_matching_ngram_size (`int`, *optional*):
|
|
The maximum ngram size to be considered for matching in the prompt. Default to 2 if not provided.
|
|
assistant_early_exit(`int`, *optional*):
|
|
If set to a positive integer, early exit of the model will be used as an assistant. Can only be used with
|
|
models that support early exit (i.e. models where logits from intermediate layers can be interpreted by the LM head).
|
|
assistant_lookbehind(`int`, *optional*, defaults to 10):
|
|
If set to a positive integer, the re-encodeing process will additionally consider the last `assistant_lookbehind` assistant tokens
|
|
to correctly align tokens. Can only be used with different tokenizers in speculative decoding.
|
|
See this [blog](https://huggingface.co/blog/universal_assisted_generation) for more details.
|
|
target_lookbehind(`int`, *optional*, defaults to 10):
|
|
If set to a positive integer, the re-encodeing process will additionally consider the last `target_lookbehind` target tokens
|
|
to correctly align tokens. Can only be used with different tokenizers in speculative decoding.
|
|
See this [blog](https://huggingface.co/blog/universal_assisted_generation) for more details.
|
|
|
|
> Parameters related to performances and compilation
|
|
|
|
compile_config (CompileConfig, *optional*):
|
|
If using a compilable cache, this controls how `generate` will `compile` the forward pass for faster
|
|
inference.
|
|
disable_compile (`bool`, *optional*):
|
|
Whether to disable the automatic compilation of the forward pass. Automatic compilation happens when
|
|
specific criteria are met, including using a compilable cache. Please open an issue if you find the
|
|
need to use this flag.
|
|
"""
|
|
|
|
extra_output_flags = ("output_attentions", "output_hidden_states", "output_scores", "output_logits")
|
|
|
|
def __init__(self, **kwargs):
|
|
# Parameters that control the length of the output
|
|
self.max_length = kwargs.pop("max_length", 20)
|
|
self.max_new_tokens = kwargs.pop("max_new_tokens", None)
|
|
self.min_length = kwargs.pop("min_length", 0)
|
|
self.min_new_tokens = kwargs.pop("min_new_tokens", None)
|
|
self.early_stopping = kwargs.pop("early_stopping", False)
|
|
self.max_time = kwargs.pop("max_time", None)
|
|
self.stop_strings = kwargs.pop("stop_strings", None)
|
|
|
|
# Parameters that control the generation strategy used
|
|
self.do_sample = kwargs.pop("do_sample", False)
|
|
self.num_beams = kwargs.pop("num_beams", 1)
|
|
self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
|
|
self.penalty_alpha = kwargs.pop("penalty_alpha", None)
|
|
self.dola_layers = kwargs.pop("dola_layers", None)
|
|
|
|
# Parameters that control the cache
|
|
self.use_cache = kwargs.pop("use_cache", True)
|
|
self.cache_implementation = kwargs.pop("cache_implementation", None)
|
|
self.cache_config = kwargs.pop("cache_config", None)
|
|
if self.cache_config is not None and not isinstance(self.cache_config, dict):
|
|
warnings.warn(
|
|
(
|
|
"Passing a CacheConfig object is deprecated and will be removed in v4.55.0 in favor of a simpler dictionary."
|
|
),
|
|
FutureWarning,
|
|
stacklevel=2,
|
|
)
|
|
self.cache_config = self.cache_config.to_dict()
|
|
|
|
self.return_legacy_cache = kwargs.pop("return_legacy_cache", None)
|
|
self.prefill_chunk_size = kwargs.pop("prefill_chunk_size", None)
|
|
|
|
# Parameters for manipulation of the model output logits
|
|
self.temperature = kwargs.pop("temperature", 1.0)
|
|
self.top_k = kwargs.pop("top_k", 50)
|
|
self.top_p = kwargs.pop("top_p", 1.0)
|
|
self.min_p = kwargs.pop("min_p", None)
|
|
self.typical_p = kwargs.pop("typical_p", 1.0)
|
|
self.epsilon_cutoff = kwargs.pop("epsilon_cutoff", 0.0)
|
|
self.eta_cutoff = kwargs.pop("eta_cutoff", 0.0)
|
|
self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
|
|
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
|
|
self.encoder_repetition_penalty = kwargs.pop("encoder_repetition_penalty", 1.0)
|
|
self.length_penalty = kwargs.pop("length_penalty", 1.0)
|
|
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
|
|
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
|
|
self.force_words_ids = kwargs.pop("force_words_ids", None)
|
|
self.renormalize_logits = kwargs.pop("renormalize_logits", False)
|
|
self.constraints = kwargs.pop("constraints", None)
|
|
self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
|
|
self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
|
|
self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
|
|
self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None)
|
|
self.suppress_tokens = kwargs.pop("suppress_tokens", None)
|
|
self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
|
|
self.sequence_bias = kwargs.pop("sequence_bias", None)
|
|
self.token_healing = kwargs.pop("token_healing", False)
|
|
self.guidance_scale = kwargs.pop("guidance_scale", None)
|
|
self.low_memory = kwargs.pop("low_memory", None)
|
|
watermarking_config = kwargs.pop("watermarking_config", None)
|
|
if watermarking_config is None:
|
|
self.watermarking_config = None
|
|
elif isinstance(watermarking_config, BaseWatermarkingConfig):
|
|
self.watermarking_config = watermarking_config
|
|
else:
|
|
self.watermarking_config = WatermarkingConfig.from_dict(watermarking_config)
|
|
|
|
# Parameters that define the output variables of `generate`
|
|
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
|
self.output_attentions = kwargs.pop("output_attentions", False)
|
|
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
|
|
self.output_scores = kwargs.pop("output_scores", False)
|
|
self.output_logits = kwargs.pop("output_logits", None)
|
|
self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
|
|
|
|
# Special tokens that can be used at generation time
|
|
self.pad_token_id = kwargs.pop("pad_token_id", None)
|
|
self.bos_token_id = kwargs.pop("bos_token_id", None)
|
|
self.eos_token_id = kwargs.pop("eos_token_id", None)
|
|
|
|
# Generation parameters exclusive to encoder-decoder models
|
|
self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
|
|
self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
|
|
|
|
# Assistant generation
|
|
self.is_assistant = False
|
|
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 20)
|
|
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "constant")
|
|
self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", 0.4)
|
|
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
|
|
self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None)
|
|
self.assistant_early_exit = kwargs.pop("assistant_early_exit", None)
|
|
## assistant generation for different tokenizers, the windows size for assistant/target model
|
|
self.assistant_lookbehind = kwargs.pop("assistant_lookbehind", 10)
|
|
self.target_lookbehind = kwargs.pop("target_lookbehind", 10)
|
|
|
|
# Performance
|
|
self.compile_config = kwargs.pop("compile_config", None)
|
|
self.disable_compile = kwargs.pop("disable_compile", False)
|
|
|
|
# The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
|
|
# interface.
|
|
self._from_model_config = kwargs.pop("_from_model_config", False)
|
|
self._commit_hash = kwargs.pop("_commit_hash", None)
|
|
self.transformers_version = kwargs.pop("transformers_version", __version__)
|
|
|
|
# Additional attributes without default values
|
|
if not self._from_model_config:
|
|
# we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
|
|
# model's default configuration file
|
|
for key, value in kwargs.items():
|
|
try:
|
|
setattr(self, key, value)
|
|
except AttributeError as err:
|
|
logger.error(f"Can't set {key} with value {value} for {self}")
|
|
raise err
|
|
|
|
# Validate the values of the attributes
|
|
self.validate()
|
|
|
|
def __hash__(self):
|
|
return hash(self.to_json_string(ignore_metadata=True))
|
|
|
|
def __eq__(self, other):
|
|
if not isinstance(other, GenerationConfig):
|
|
return False
|
|
|
|
self_without_metadata = self.to_json_string(use_diff=False, ignore_metadata=True)
|
|
other_without_metadata = other.to_json_string(use_diff=False, ignore_metadata=True)
|
|
return self_without_metadata == other_without_metadata
|
|
|
|
def __repr__(self):
|
|
return f"{self.__class__.__name__} {self.to_json_string(ignore_metadata=True)}"
|
|
|
|
def get_generation_mode(self, assistant_model: Optional["PreTrainedModel"] = None) -> GenerationMode:
|
|
"""
|
|
Returns the generation mode triggered by the [`GenerationConfig`] instance.
|
|
|
|
Arg:
|
|
assistant_model (`PreTrainedModel`, *optional*):
|
|
The assistant model to be used for assisted generation. If set, the generation mode will be
|
|
assisted generation.
|
|
|
|
Returns:
|
|
`GenerationMode`: The generation mode triggered by the instance.
|
|
"""
|
|
# TODO joao: find out a way of not depending on external fields (e.g. `assistant_model`), then make this a
|
|
# property and part of the `__repr__`
|
|
if self.constraints is not None or self.force_words_ids is not None:
|
|
generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
|
|
elif self.num_beams == 1:
|
|
if self.do_sample is False:
|
|
if (
|
|
self.top_k is not None
|
|
and self.top_k > 1
|
|
and self.penalty_alpha is not None
|
|
and self.penalty_alpha > 0
|
|
):
|
|
generation_mode = GenerationMode.CONTRASTIVE_SEARCH
|
|
else:
|
|
generation_mode = GenerationMode.GREEDY_SEARCH
|
|
else:
|
|
generation_mode = GenerationMode.SAMPLE
|
|
else:
|
|
if self.num_beam_groups > 1:
|
|
generation_mode = GenerationMode.GROUP_BEAM_SEARCH
|
|
elif self.do_sample is True:
|
|
generation_mode = GenerationMode.BEAM_SAMPLE
|
|
else:
|
|
generation_mode = GenerationMode.BEAM_SEARCH
|
|
|
|
# Assisted generation may extend some generation modes
|
|
if (
|
|
assistant_model is not None
|
|
or self.prompt_lookup_num_tokens is not None
|
|
or self.assistant_early_exit is not None
|
|
):
|
|
if generation_mode in ("greedy_search", "sample"):
|
|
generation_mode = GenerationMode.ASSISTED_GENERATION
|
|
else:
|
|
logger.warning(
|
|
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
|
|
"is only supported with Greedy Search and Sample. However, the base decoding mode (based on "
|
|
f"current flags) is {generation_mode} -- some of the set flags will be ignored."
|
|
)
|
|
|
|
# DoLa generation may extend some generation modes
|
|
if self.dola_layers is not None:
|
|
if generation_mode in ("greedy_search", "sample"):
|
|
generation_mode = GenerationMode.DOLA_GENERATION
|
|
else:
|
|
logger.warning(
|
|
"You've set `dola_layers`, which triggers DoLa generate. Currently, DoLa generate "
|
|
"is only supported with Greedy Search and Sample. However, the base decoding mode (based on "
|
|
f"current flags) is {generation_mode} -- some of the set flags will be ignored."
|
|
)
|
|
return generation_mode
|
|
|
|
@deprecate_kwarg("is_init", version="4.54.0")
|
|
def validate(self, strict=False):
|
|
"""
|
|
Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
|
|
of parameterization that can be detected as incorrect from the configuration instance alone.
|
|
|
|
Note that some parameters not validated here are best validated at generate runtime, as they may depend on
|
|
other inputs and/or the model, such as parameters related to the generation length.
|
|
|
|
Args:
|
|
strict (bool): If True, raise an exception for any issues found. If False, only log issues.
|
|
"""
|
|
minor_issues = {} # format: {attribute_name: issue_description}
|
|
|
|
# 1. Validation of individual attributes
|
|
# 1.1. Decoding attributes
|
|
if self.early_stopping not in {True, False, "never"}:
|
|
raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
|
|
if self.max_new_tokens is not None and self.max_new_tokens <= 0:
|
|
raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.")
|
|
if self.pad_token_id is not None and self.pad_token_id < 0:
|
|
minor_issues["pad_token_id"] = (
|
|
f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch "
|
|
"generating, if there is padding. Please set `pad_token_id` explicitly as "
|
|
"`model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation"
|
|
)
|
|
# 1.2. Cache attributes
|
|
if self.cache_implementation is not None and self.cache_implementation not in ALL_CACHE_IMPLEMENTATIONS:
|
|
raise ValueError(
|
|
f"Invalid `cache_implementation` ({self.cache_implementation}). Choose one of: "
|
|
f"{ALL_CACHE_IMPLEMENTATIONS}"
|
|
)
|
|
# 1.3. Performance attributes
|
|
if self.compile_config is not None and not isinstance(self.compile_config, CompileConfig):
|
|
raise ValueError(
|
|
f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an "
|
|
"instance of `CompileConfig`."
|
|
)
|
|
# 1.4. Watermarking attributes
|
|
if self.watermarking_config is not None:
|
|
if not (
|
|
isinstance(self.watermarking_config, WatermarkingConfig)
|
|
or isinstance(self.watermarking_config, SynthIDTextWatermarkingConfig)
|
|
):
|
|
minor_issues["watermarking_config"] = (
|
|
"`watermarking_config` as a dict is deprecated and will be removed in v4.54.0. Please construct "
|
|
"`watermarking_config` object with `WatermarkingConfig` or `SynthIDTextWatermarkingConfig` class."
|
|
)
|
|
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
|
|
self.watermarking_config.validate()
|
|
|
|
# 2. Validation of attribute combinations
|
|
# 2.1. detect sampling-only parameterization when not in sampling mode
|
|
if self.do_sample is False:
|
|
greedy_wrong_parameter_msg = (
|
|
"`do_sample` is set to `False`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
|
|
"used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
|
|
)
|
|
if self.temperature is not None and self.temperature != 1.0:
|
|
minor_issues["temperature"] = greedy_wrong_parameter_msg.format(
|
|
flag_name="temperature", flag_value=self.temperature
|
|
)
|
|
if self.top_p is not None and self.top_p != 1.0:
|
|
minor_issues["top_p"] = greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p)
|
|
if self.min_p is not None:
|
|
minor_issues["min_p"] = greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p)
|
|
if self.typical_p is not None and self.typical_p != 1.0:
|
|
minor_issues["typical_p"] = greedy_wrong_parameter_msg.format(
|
|
flag_name="typical_p", flag_value=self.typical_p
|
|
)
|
|
if (
|
|
self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
|
|
): # contrastive search uses top_k
|
|
minor_issues["top_k"] = greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k)
|
|
if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
|
|
minor_issues["epsilon_cutoff"] = greedy_wrong_parameter_msg.format(
|
|
flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff
|
|
)
|
|
if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
|
|
minor_issues["eta_cutoff"] = greedy_wrong_parameter_msg.format(
|
|
flag_name="eta_cutoff", flag_value=self.eta_cutoff
|
|
)
|
|
|
|
# 2.2. detect beam-only parameterization when not in beam mode
|
|
if self.num_beams == 1:
|
|
single_beam_wrong_parameter_msg = (
|
|
"`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
|
|
"in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`."
|
|
)
|
|
if self.early_stopping is not False:
|
|
minor_issues["early_stopping"] = single_beam_wrong_parameter_msg.format(
|
|
flag_name="early_stopping", flag_value=self.early_stopping
|
|
)
|
|
if self.num_beam_groups is not None and self.num_beam_groups != 1:
|
|
minor_issues["num_beam_groups"] = single_beam_wrong_parameter_msg.format(
|
|
flag_name="num_beam_groups", flag_value=self.num_beam_groups
|
|
)
|
|
if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
|
|
minor_issues["diversity_penalty"] = single_beam_wrong_parameter_msg.format(
|
|
flag_name="diversity_penalty", flag_value=self.diversity_penalty
|
|
)
|
|
if self.length_penalty is not None and self.length_penalty != 1.0:
|
|
minor_issues["length_penalty"] = single_beam_wrong_parameter_msg.format(
|
|
flag_name="length_penalty", flag_value=self.length_penalty
|
|
)
|
|
if self.constraints is not None:
|
|
minor_issues["constraints"] = single_beam_wrong_parameter_msg.format(
|
|
flag_name="constraints", flag_value=self.constraints
|
|
)
|
|
# DoLa generation needs num_beams == 1
|
|
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
|
|
minor_issues["repetition_penalty"] = (
|
|
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of "
|
|
f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for "
|
|
"DoLa decoding is `repetition_penalty>=1.2`.",
|
|
)
|
|
|
|
# 2.3. detect incorrect parameterization specific to advanced beam modes
|
|
else:
|
|
# constrained beam search
|
|
if self.constraints is not None or self.force_words_ids is not None:
|
|
constrained_wrong_parameter_msg = (
|
|
"one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. "
|
|
"However, `{flag_name}` is set to `{flag_value}`, which is incompatible with this generation "
|
|
"mode. Set `constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue."
|
|
)
|
|
if self.do_sample is True:
|
|
raise ValueError(
|
|
constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
|
|
)
|
|
if self.num_beam_groups is not None and self.num_beam_groups != 1:
|
|
raise ValueError(
|
|
constrained_wrong_parameter_msg.format(
|
|
flag_name="num_beam_groups", flag_value=self.num_beam_groups
|
|
)
|
|
)
|
|
# group beam search
|
|
elif self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
|
|
group_error_prefix = (
|
|
"`diversity_penalty` is not 0.0 or `num_beam_groups` is not 1, triggering group beam search. In "
|
|
"this generation mode, "
|
|
)
|
|
if self.do_sample is True:
|
|
raise ValueError(group_error_prefix + "`do_sample` must be set to `False`")
|
|
if self.num_beams % self.num_beam_groups != 0:
|
|
raise ValueError(group_error_prefix + "`num_beams` should be divisible by `num_beam_groups`")
|
|
if self.diversity_penalty == 0.0:
|
|
raise ValueError(
|
|
group_error_prefix
|
|
+ "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
|
|
)
|
|
|
|
# 2.4. check `num_return_sequences`
|
|
if self.num_return_sequences != 1:
|
|
if self.num_beams == 1:
|
|
if self.do_sample is False:
|
|
raise ValueError(
|
|
"Greedy methods without beam search do not support `num_return_sequences` different than 1 "
|
|
f"(got {self.num_return_sequences})."
|
|
)
|
|
elif self.num_return_sequences > self.num_beams:
|
|
raise ValueError(
|
|
f"`num_return_sequences` ({self.num_return_sequences}) has to be smaller or equal to `num_beams` "
|
|
f"({self.num_beams})."
|
|
)
|
|
|
|
# 2.5. check cache-related arguments
|
|
if self.use_cache is False:
|
|
# In this case, all cache-related arguments should be unset. However, since `use_cache=False` is often used
|
|
# passed to `generate` directly to hot-fix cache issues, let's raise a warning instead of an error
|
|
# (otherwise a user might need to overwrite several parameters).
|
|
no_cache_warning = (
|
|
"You have set `use_cache` to `False`, but {cache_arg} is set to {cache_arg_value}. {cache_arg} will "
|
|
"have no effect."
|
|
)
|
|
for arg_name in ("cache_implementation", "cache_config", "return_legacy_cache"):
|
|
if getattr(self, arg_name) is not None:
|
|
minor_issues[arg_name] = no_cache_warning.format(
|
|
cache_arg=arg_name, cache_arg_value=getattr(self, arg_name)
|
|
)
|
|
|
|
# 2.6. other incorrect combinations
|
|
if self.return_dict_in_generate is not True:
|
|
for extra_output_flag in self.extra_output_flags:
|
|
if getattr(self, extra_output_flag) is True:
|
|
minor_issues[extra_output_flag] = (
|
|
f"`return_dict_in_generate` is NOT set to `True`, but `{extra_output_flag}` is. When "
|
|
f"`return_dict_in_generate` is not `True`, `{extra_output_flag}` is ignored."
|
|
)
|
|
|
|
# 3. Check common issue: passing `generate` arguments inside the generation config
|
|
generate_arguments = (
|
|
"logits_processor",
|
|
"stopping_criteria",
|
|
"prefix_allowed_tokens_fn",
|
|
"synced_gpus",
|
|
"assistant_model",
|
|
"streamer",
|
|
"negative_prompt_ids",
|
|
"negative_prompt_attention_mask",
|
|
"use_model_defaults",
|
|
)
|
|
for arg in generate_arguments:
|
|
if hasattr(self, arg):
|
|
raise ValueError(
|
|
f"Argument `{arg}` is not a valid argument of `GenerationConfig`. It should be passed to "
|
|
"`generate()` (or a pipeline) directly."
|
|
)
|
|
|
|
# Finally, handle caught minor issues. With default parameterization, we will throw a minimal warning.
|
|
if len(minor_issues) > 0:
|
|
# Full list of issues with potential fixes
|
|
info_message = []
|
|
for attribute_name, issue_description in minor_issues.items():
|
|
info_message.append(f"- `{attribute_name}`: {issue_description}")
|
|
info_message = "\n".join(info_message)
|
|
info_message += (
|
|
"\nIf you're using a pretrained model, note that some of these attributes may be set through the "
|
|
"model's `generation_config.json` file."
|
|
)
|
|
|
|
if strict:
|
|
raise ValueError("GenerationConfig is invalid: \n" + info_message)
|
|
else:
|
|
attributes_with_issues = list(minor_issues.keys())
|
|
warning_message = (
|
|
f"The following generation flags are not valid and may be ignored: {attributes_with_issues}."
|
|
)
|
|
if logging.get_verbosity() >= logging.WARNING:
|
|
warning_message += " Set `TRANSFORMERS_VERBOSITY=info` for more details."
|
|
logger.warning(warning_message)
|
|
logger.info(info_message)
|
|
|
|
def save_pretrained(
|
|
self,
|
|
save_directory: Union[str, os.PathLike],
|
|
config_file_name: Optional[Union[str, os.PathLike]] = None,
|
|
push_to_hub: bool = False,
|
|
**kwargs,
|
|
):
|
|
r"""
|
|
Save a generation configuration object to the directory `save_directory`, so that it can be re-loaded using the
|
|
[`~GenerationConfig.from_pretrained`] class method.
|
|
|
|
Args:
|
|
save_directory (`str` or `os.PathLike`):
|
|
Directory where the configuration JSON file will be saved (will be created if it does not exist).
|
|
config_file_name (`str` or `os.PathLike`, *optional*, defaults to `"generation_config.json"`):
|
|
Name of the generation configuration JSON file to be saved in `save_directory`.
|
|
push_to_hub (`bool`, *optional*, defaults to `False`):
|
|
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
|
|
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
|
namespace).
|
|
kwargs (`dict[str, Any]`, *optional*):
|
|
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
|
"""
|
|
|
|
# At save time, validate the instance enforcing strictness -- if any warning/exception would be thrown, we
|
|
# refuse to save the instance.
|
|
# This strictness is enforced to prevent bad configurations from being saved and re-used.
|
|
try:
|
|
self.validate(strict=True)
|
|
except ValueError as exc:
|
|
raise ValueError(str(exc) + "\n\nFix these issues to save the configuration.")
|
|
|
|
use_auth_token = kwargs.pop("use_auth_token", None)
|
|
|
|
if use_auth_token is not None:
|
|
warnings.warn(
|
|
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. "
|
|
"Please use `token` instead.",
|
|
FutureWarning,
|
|
)
|
|
if kwargs.get("token", None) is not None:
|
|
raise ValueError(
|
|
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
|
|
)
|
|
kwargs["token"] = use_auth_token
|
|
|
|
config_file_name = config_file_name if config_file_name is not None else GENERATION_CONFIG_NAME
|
|
|
|
if os.path.isfile(save_directory):
|
|
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
|
|
|
os.makedirs(save_directory, exist_ok=True)
|
|
|
|
if push_to_hub:
|
|
commit_message = kwargs.pop("commit_message", None)
|
|
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
|
repo_id = self._create_repo(repo_id, **kwargs)
|
|
files_timestamps = self._get_files_timestamps(save_directory)
|
|
|
|
output_config_file = os.path.join(save_directory, config_file_name)
|
|
|
|
self.to_json_file(output_config_file, use_diff=True)
|
|
logger.info(f"Configuration saved in {output_config_file}")
|
|
|
|
if push_to_hub:
|
|
self._upload_modified_files(
|
|
save_directory,
|
|
repo_id,
|
|
files_timestamps,
|
|
commit_message=commit_message,
|
|
token=kwargs.get("token"),
|
|
)
|
|
|
|
@classmethod
|
|
def from_pretrained(
|
|
cls,
|
|
pretrained_model_name: Union[str, os.PathLike],
|
|
config_file_name: Optional[Union[str, os.PathLike]] = None,
|
|
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
|
force_download: bool = False,
|
|
local_files_only: bool = False,
|
|
token: Optional[Union[str, bool]] = None,
|
|
revision: str = "main",
|
|
**kwargs,
|
|
) -> "GenerationConfig":
|
|
r"""
|
|
Instantiate a [`GenerationConfig`] from a generation configuration file.
|
|
|
|
Args:
|
|
pretrained_model_name (`str` or `os.PathLike`):
|
|
This can be either:
|
|
|
|
- a string, the *model id* of a pretrained model configuration hosted inside a model repo on
|
|
huggingface.co.
|
|
- a path to a *directory* containing a configuration file saved using the
|
|
[`~GenerationConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
|
|
config_file_name (`str` or `os.PathLike`, *optional*, defaults to `"generation_config.json"`):
|
|
Name of the generation configuration JSON file to be loaded from `pretrained_model_name`.
|
|
cache_dir (`str` or `os.PathLike`, *optional*):
|
|
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
|
standard cache should not be used.
|
|
force_download (`bool`, *optional*, defaults to `False`):
|
|
Whether or not to force to (re-)download the configuration files and override the cached versions if
|
|
they exist.
|
|
resume_download:
|
|
Deprecated and ignored. All downloads are now resumed by default when possible.
|
|
Will be removed in v5 of Transformers.
|
|
proxies (`dict[str, str]`, *optional*):
|
|
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
|
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
|
token (`str` or `bool`, *optional*):
|
|
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
|
|
the token generated when running `hf auth login` (stored in `~/.huggingface`).
|
|
revision (`str`, *optional*, defaults to `"main"`):
|
|
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
|
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
|
identifier allowed by git.
|
|
|
|
<Tip>
|
|
|
|
To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
|
|
|
|
</Tip>
|
|
|
|
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
|
If `False`, then this function returns just the final configuration object.
|
|
|
|
If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
|
|
dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
|
|
part of `kwargs` which has not been used to update `config` and is otherwise ignored.
|
|
subfolder (`str`, *optional*, defaults to `""`):
|
|
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
|
specify the folder name here.
|
|
kwargs (`dict[str, Any]`, *optional*):
|
|
The values in kwargs of any keys which are configuration attributes will be used to override the loaded
|
|
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
|
|
by the `return_unused_kwargs` keyword parameter.
|
|
|
|
Returns:
|
|
[`GenerationConfig`]: The configuration object instantiated from this pretrained model.
|
|
|
|
Examples:
|
|
|
|
```python
|
|
>>> from transformers import GenerationConfig
|
|
|
|
>>> # Download configuration from huggingface.co and cache.
|
|
>>> generation_config = GenerationConfig.from_pretrained("openai-community/gpt2")
|
|
|
|
>>> # E.g. config was saved using *save_pretrained('./test/saved_model/')*
|
|
>>> generation_config.save_pretrained("./test/saved_model/")
|
|
>>> generation_config = GenerationConfig.from_pretrained("./test/saved_model/")
|
|
|
|
>>> # You can also specify configuration names to your generation configuration file
|
|
>>> generation_config.save_pretrained("./test/saved_model/", config_file_name="my_configuration.json")
|
|
>>> generation_config = GenerationConfig.from_pretrained("./test/saved_model/", "my_configuration.json")
|
|
|
|
>>> # If you'd like to try a minor variation to an existing configuration, you can also pass generation
|
|
>>> # arguments to `.from_pretrained()`. Be mindful that typos and unused arguments will be ignored
|
|
>>> generation_config, unused_kwargs = GenerationConfig.from_pretrained(
|
|
... "openai-community/gpt2", top_k=1, foo=False, do_sample=True, return_unused_kwargs=True
|
|
... )
|
|
>>> generation_config.top_k
|
|
1
|
|
|
|
>>> unused_kwargs
|
|
{'foo': False}
|
|
```"""
|
|
config_file_name = config_file_name if config_file_name is not None else GENERATION_CONFIG_NAME
|
|
|
|
resume_download = kwargs.pop("resume_download", None)
|
|
proxies = kwargs.pop("proxies", None)
|
|
use_auth_token = kwargs.pop("use_auth_token", None)
|
|
subfolder = kwargs.pop("subfolder", "")
|
|
from_pipeline = kwargs.pop("_from_pipeline", None)
|
|
from_auto_class = kwargs.pop("_from_auto", False)
|
|
commit_hash = kwargs.pop("_commit_hash", None)
|
|
|
|
if use_auth_token is not None:
|
|
warnings.warn(
|
|
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
|
FutureWarning,
|
|
)
|
|
if token is not None:
|
|
raise ValueError(
|
|
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
|
|
)
|
|
token = use_auth_token
|
|
|
|
user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
|
|
if from_pipeline is not None:
|
|
user_agent["using_pipeline"] = from_pipeline
|
|
|
|
config_path = os.path.join(pretrained_model_name, config_file_name)
|
|
config_path = str(config_path)
|
|
|
|
is_local = os.path.exists(config_path)
|
|
if os.path.isfile(os.path.join(subfolder, config_path)):
|
|
# Special case when config_path is a local file
|
|
resolved_config_file = config_path
|
|
is_local = True
|
|
elif is_remote_url(config_path):
|
|
configuration_file = config_path
|
|
resolved_config_file = download_url(config_path)
|
|
else:
|
|
configuration_file = config_file_name
|
|
try:
|
|
# Load from local folder or from cache or download from model Hub and cache
|
|
resolved_config_file = cached_file(
|
|
pretrained_model_name,
|
|
configuration_file,
|
|
cache_dir=cache_dir,
|
|
force_download=force_download,
|
|
proxies=proxies,
|
|
resume_download=resume_download,
|
|
local_files_only=local_files_only,
|
|
token=token,
|
|
user_agent=user_agent,
|
|
revision=revision,
|
|
subfolder=subfolder,
|
|
_commit_hash=commit_hash,
|
|
)
|
|
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
|
|
except OSError:
|
|
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
|
|
# the original exception.
|
|
raise
|
|
except Exception:
|
|
# For any other exception, we throw a generic error.
|
|
raise OSError(
|
|
f"Can't load the configuration of '{pretrained_model_name}'. If you were trying to load it"
|
|
" from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
|
|
f" name. Otherwise, make sure '{pretrained_model_name}' is the correct path to a directory"
|
|
f" containing a {configuration_file} file"
|
|
)
|
|
|
|
try:
|
|
# Load config dict
|
|
config_dict = cls._dict_from_json_file(resolved_config_file)
|
|
config_dict["_commit_hash"] = commit_hash
|
|
except (json.JSONDecodeError, UnicodeDecodeError):
|
|
raise OSError(f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file.")
|
|
|
|
if is_local:
|
|
logger.info(f"loading configuration file {resolved_config_file}")
|
|
else:
|
|
logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
|
|
|
|
if kwargs.get("return_unused_kwargs") is True:
|
|
config, unused_kwargs = cls.from_dict(config_dict, **kwargs)
|
|
config._original_object_hash = hash(config) # Hash to detect whether the instance was modified
|
|
return config, unused_kwargs
|
|
else:
|
|
config = cls.from_dict(config_dict, **kwargs)
|
|
config._original_object_hash = hash(config) # Hash to detect whether the instance was modified
|
|
return config
|
|
|
|
@classmethod
|
|
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
|
|
with open(json_file, "r", encoding="utf-8") as reader:
|
|
text = reader.read()
|
|
return json.loads(text)
|
|
|
|
@classmethod
|
|
def from_dict(cls, config_dict: dict[str, Any], **kwargs) -> "GenerationConfig":
|
|
"""
|
|
Instantiates a [`GenerationConfig`] from a Python dictionary of parameters.
|
|
|
|
Args:
|
|
config_dict (`dict[str, Any]`):
|
|
Dictionary that will be used to instantiate the configuration object.
|
|
kwargs (`dict[str, Any]`):
|
|
Additional parameters from which to initialize the configuration object.
|
|
|
|
Returns:
|
|
[`GenerationConfig`]: The configuration object instantiated from those parameters.
|
|
"""
|
|
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
|
|
# Those arguments may be passed along for our internal telemetry.
|
|
# We remove them so they don't appear in `return_unused_kwargs`.
|
|
kwargs.pop("_from_auto", None)
|
|
kwargs.pop("_from_pipeline", None)
|
|
# The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
|
|
if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
|
|
kwargs["_commit_hash"] = config_dict["_commit_hash"]
|
|
|
|
# The line below allows model-specific config to be loaded as well through kwargs, with safety checks.
|
|
# See https://github.com/huggingface/transformers/pull/21269
|
|
config = cls(**{**config_dict, **kwargs})
|
|
unused_kwargs = config.update(**kwargs)
|
|
|
|
logger.info(f"Generate config {config}")
|
|
if return_unused_kwargs:
|
|
return config, unused_kwargs
|
|
else:
|
|
return config
|
|
|
|
def dict_torch_dtype_to_str(self, d: dict[str, Any]) -> None:
|
|
"""
|
|
Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
|
|
converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
|
|
string, which can then be stored in the json format.
|
|
"""
|
|
if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
|
|
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
|
|
for value in d.values():
|
|
if isinstance(value, dict):
|
|
self.dict_torch_dtype_to_str(value)
|
|
|
|
def to_diff_dict(self) -> dict[str, Any]:
|
|
"""
|
|
Removes all attributes from config which correspond to the default config attributes for better readability and
|
|
serializes to a Python dictionary.
|
|
|
|
Returns:
|
|
`dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
|
|
"""
|
|
config_dict = self.to_dict()
|
|
|
|
# get the default config dict
|
|
default_config_dict = GenerationConfig().to_dict()
|
|
|
|
serializable_config_dict = {}
|
|
|
|
# only serialize values that differ from the default config
|
|
for key, value in config_dict.items():
|
|
if key not in default_config_dict or key == "transformers_version" or value != default_config_dict[key]:
|
|
serializable_config_dict[key] = value
|
|
|
|
self.dict_torch_dtype_to_str(serializable_config_dict)
|
|
return serializable_config_dict
|
|
|
|
def to_dict(self) -> dict[str, Any]:
|
|
"""
|
|
Serializes this instance to a Python dictionary.
|
|
|
|
Returns:
|
|
`dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
|
"""
|
|
output = copy.deepcopy(self.__dict__)
|
|
|
|
# Fields to ignore at serialization time
|
|
if "_commit_hash" in output:
|
|
del output["_commit_hash"]
|
|
if "_original_object_hash" in output:
|
|
del output["_original_object_hash"]
|
|
if "compile_config" in output:
|
|
del output["compile_config"]
|
|
|
|
# Transformers version when serializing this file
|
|
output["transformers_version"] = __version__
|
|
|
|
self.dict_torch_dtype_to_str(output)
|
|
return output
|
|
|
|
def to_json_string(self, use_diff: bool = True, ignore_metadata: bool = False) -> str:
|
|
"""
|
|
Serializes this instance to a JSON string.
|
|
|
|
Args:
|
|
use_diff (`bool`, *optional*, defaults to `True`):
|
|
If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
|
|
is serialized to JSON string.
|
|
ignore_metadata (`bool`, *optional*, defaults to `False`):
|
|
Whether to ignore the metadata fields present in the instance
|
|
|
|
Returns:
|
|
`str`: String containing all the attributes that make up this configuration instance in JSON format.
|
|
"""
|
|
if use_diff is True:
|
|
config_dict = self.to_diff_dict()
|
|
else:
|
|
config_dict = self.to_dict()
|
|
|
|
if ignore_metadata:
|
|
for metadata_field in METADATA_FIELDS:
|
|
config_dict.pop(metadata_field, None)
|
|
|
|
def convert_keys_to_string(obj):
|
|
if isinstance(obj, dict):
|
|
return {str(key): convert_keys_to_string(value) for key, value in obj.items()}
|
|
elif isinstance(obj, list):
|
|
return [convert_keys_to_string(item) for item in obj]
|
|
else:
|
|
return obj
|
|
|
|
def convert_dataclass_to_dict(obj):
|
|
if isinstance(obj, dict):
|
|
return {key: convert_dataclass_to_dict(value) for key, value in obj.items()}
|
|
elif is_dataclass(obj):
|
|
return obj.to_dict()
|
|
else:
|
|
return obj
|
|
|
|
config_dict = convert_keys_to_string(config_dict)
|
|
config_dict = convert_dataclass_to_dict(config_dict)
|
|
|
|
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
|
|
|
def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
|
|
"""
|
|
Save this instance to a JSON file.
|
|
|
|
Args:
|
|
json_file_path (`str` or `os.PathLike`):
|
|
Path to the JSON file in which this configuration instance's parameters will be saved.
|
|
use_diff (`bool`, *optional*, defaults to `True`):
|
|
If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
|
|
is serialized to JSON file.
|
|
"""
|
|
with open(json_file_path, "w", encoding="utf-8") as writer:
|
|
writer.write(self.to_json_string(use_diff=use_diff))
|
|
|
|
@classmethod
|
|
def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig":
|
|
"""
|
|
Instantiates a [`GenerationConfig`] from a [`PretrainedConfig`]. This function is useful to convert legacy
|
|
[`PretrainedConfig`] objects, which may contain generation parameters, into a stand-alone [`GenerationConfig`].
|
|
|
|
Args:
|
|
model_config (`PretrainedConfig`):
|
|
The model config that will be used to instantiate the generation config.
|
|
|
|
Returns:
|
|
[`GenerationConfig`]: The configuration object instantiated from those parameters.
|
|
"""
|
|
config_dict = model_config.to_dict()
|
|
config_dict.pop("_from_model_config", None)
|
|
|
|
# Removes all `None` from the model config dict -- this lets the generation config defaults to take hold
|
|
config_dict = {key: value for key, value in config_dict.items() if value is not None}
|
|
|
|
generation_config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)
|
|
|
|
# Special case: some models have generation attributes set in the decoder. Use them if still unset in the
|
|
# generation config (which in turn is defined from the outer attributes of model config).
|
|
decoder_config = model_config.get_text_config(decoder=True)
|
|
if decoder_config is not model_config:
|
|
default_generation_config = GenerationConfig()
|
|
decoder_config_dict = decoder_config.to_dict()
|
|
for attr in generation_config.to_dict().keys():
|
|
is_unset = getattr(generation_config, attr) == getattr(default_generation_config, attr)
|
|
if attr in decoder_config_dict and is_unset:
|
|
setattr(generation_config, attr, decoder_config_dict[attr])
|
|
|
|
# If any `output_...` flag is set to `True`, we ensure `return_dict_in_generate` is set to `True`.
|
|
if generation_config.return_dict_in_generate is False:
|
|
if any(
|
|
getattr(generation_config, extra_output_flag, False)
|
|
for extra_output_flag in generation_config.extra_output_flags
|
|
):
|
|
generation_config.return_dict_in_generate = True
|
|
|
|
# Hash to detect whether the instance was modified
|
|
generation_config._original_object_hash = hash(generation_config)
|
|
return generation_config
|
|
|
|
def update(self, **kwargs):
|
|
"""
|
|
Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
|
|
returning all the unused kwargs.
|
|
|
|
Args:
|
|
kwargs (`dict[str, Any]`):
|
|
Dictionary of attributes to tentatively update this class.
|
|
|
|
Returns:
|
|
`dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
|
|
"""
|
|
to_remove = []
|
|
for key, value in kwargs.items():
|
|
if hasattr(self, key):
|
|
setattr(self, key, value)
|
|
to_remove.append(key)
|
|
|
|
# Confirm that the updated instance is still valid
|
|
self.validate()
|
|
|
|
# Remove all the attributes that were updated, without modifying the input dict
|
|
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
|
|
return unused_kwargs
|
|
|
|
|
|
@dataclass
|
|
class BaseWatermarkingConfig(ABC):
|
|
"""Generic watermarking config"""
|
|
|
|
@classmethod
|
|
def from_dict(cls, config_dict, **kwargs):
|
|
"""
|
|
Constructs a BaseWatermarkingConfig instance from a dictionary of parameters.
|
|
|
|
Args:
|
|
config_dict (dict[str, Any]): Dictionary containing configuration parameters.
|
|
**kwargs: Additional keyword arguments to override dictionary values.
|
|
|
|
Returns:
|
|
BaseWatermarkingConfig: Instance of BaseWatermarkingConfig constructed from the dictionary.
|
|
"""
|
|
config = cls(**config_dict)
|
|
to_remove = []
|
|
for key, value in kwargs.items():
|
|
if hasattr(config, key):
|
|
setattr(config, key, value)
|
|
to_remove.append(key)
|
|
for key in to_remove:
|
|
kwargs.pop(key, None)
|
|
return config
|
|
|
|
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
|
"""
|
|
Save this instance to a JSON file.
|
|
|
|
Args:
|
|
json_file_path (Union[str, os.PathLike]): Path to the JSON file in which this configuration instance's parameters will be saved.
|
|
"""
|
|
with open(json_file_path, "w", encoding="utf-8") as writer:
|
|
config_dict = self.to_dict()
|
|
json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
|
|
|
writer.write(json_string)
|
|
|
|
def to_dict(self) -> dict[str, Any]:
|
|
"""
|
|
Serializes this instance to a Python dictionary.
|
|
|
|
Returns:
|
|
dict[str, Any]: Dictionary of all the attributes that make up this configuration instance.
|
|
"""
|
|
output = copy.deepcopy(self.__dict__)
|
|
return output
|
|
|
|
def __iter__(self):
|
|
for attr, value in copy.deepcopy(self.__dict__).items():
|
|
yield attr, value
|
|
|
|
def __repr__(self):
|
|
return f"{self.__class__.__name__} {self.to_json_string()}"
|
|
|
|
def to_json_string(self):
|
|
"""
|
|
Serializes this instance to a JSON formatted string.
|
|
|
|
Returns:
|
|
str: JSON formatted string representing the configuration instance.
|
|
"""
|
|
return json.dumps(self.__dict__, indent=2) + "\n"
|
|
|
|
def update(self, **kwargs):
|
|
"""
|
|
Update the configuration attributes with new values.
|
|
|
|
Args:
|
|
**kwargs: Keyword arguments representing configuration attributes and their new values.
|
|
"""
|
|
for key, value in kwargs.items():
|
|
if hasattr(self, key):
|
|
setattr(self, key, value)
|
|
|
|
@abstractmethod
|
|
def validate(self): ...
|
|
|
|
@abstractmethod
|
|
def construct_processor(self, vocab_size): ...
|
|
|
|
|
|
@dataclass
|
|
class WatermarkingConfig(BaseWatermarkingConfig):
|
|
"""
|
|
Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`.
|
|
See [this paper](https://huggingface.co/papers/2306.04634) for more details on the arguments.
|
|
|
|
Accepts the following keys:
|
|
- greenlist_ratio (`float`):
|
|
Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
|
|
- bias (`float`):
|
|
Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
|
|
- hashing_key (`int`):
|
|
Hashing key used for watermarking. Defaults to 15485863 (the millionth prime).
|
|
- seeding_scheme (`str`):
|
|
Algorithm to use for watermarking. Accepts values:
|
|
- "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper)
|
|
- "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper)
|
|
The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
|
|
- context_width(`int`):
|
|
The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
greenlist_ratio: Optional[float] = 0.25,
|
|
bias: Optional[float] = 2.0,
|
|
hashing_key: Optional[int] = 15485863,
|
|
seeding_scheme: Optional[str] = "lefthash",
|
|
context_width: Optional[int] = 1,
|
|
):
|
|
self.greenlist_ratio = greenlist_ratio
|
|
self.bias = bias
|
|
self.hashing_key = hashing_key
|
|
self.seeding_scheme = seeding_scheme
|
|
self.context_width = context_width
|
|
|
|
def validate(self):
|
|
watermark_missing_arg_msg = (
|
|
"Some of the keys in `watermarking_config` are defined incorrectly. `{key}` should be {correct_value}` "
|
|
"but found {found_value}"
|
|
)
|
|
if self.seeding_scheme not in ["selfhash", "lefthash"]:
|
|
raise ValueError(
|
|
watermark_missing_arg_msg.format(
|
|
key="seeding_scheme",
|
|
correct_value="[`selfhash`, `lefthash`]",
|
|
found_value=self.seeding_scheme,
|
|
),
|
|
)
|
|
if not 0.0 <= self.greenlist_ratio <= 1.0:
|
|
raise ValueError(
|
|
watermark_missing_arg_msg.format(
|
|
key="greenlist_ratio",
|
|
correct_value="in range between 0.0 and 1.0",
|
|
found_value=self.seeding_scheme,
|
|
),
|
|
)
|
|
if not self.context_width >= 1:
|
|
raise ValueError(
|
|
watermark_missing_arg_msg.format(
|
|
key="context_width",
|
|
correct_value="a positive integer",
|
|
found_value=self.context_width,
|
|
),
|
|
)
|
|
|
|
def construct_processor(self, vocab_size: int, device) -> "WatermarkLogitsProcessor":
|
|
return WatermarkLogitsProcessor(
|
|
vocab_size=vocab_size,
|
|
device=device,
|
|
greenlist_ratio=self.greenlist_ratio,
|
|
bias=self.bias,
|
|
hashing_key=self.hashing_key,
|
|
seeding_scheme=self.seeding_scheme,
|
|
context_width=self.context_width,
|
|
)
|
|
|
|
|
|
@dataclass
|
|
class SynthIDTextWatermarkingConfig(BaseWatermarkingConfig):
|
|
"""
|
|
Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`.
|
|
See [this paper](https://www.nature.com/articles/s41586-024-08025-4) for more details on the arguments.
|
|
|
|
Args:
|
|
ngram_len (`int`):
|
|
Ngram length.
|
|
keys (`list[int]`):
|
|
A sequence of watermarking keys, one for each depth.
|
|
context_history_size (`int`, *optional*, defaults to 1024):
|
|
Size of the tensor to keep track of seen contexts.
|
|
sampling_table_seed (`int`, *optional*, defaults to 0):
|
|
Random seed to generate the sampling table.
|
|
sampling_table_size (`int`, *optional*, defaults to 65536):
|
|
Size of the sampling table.
|
|
skip_first_ngram_calls (`bool`, *optional*, defaults to `False`):
|
|
Whether to skip first ngram calls.
|
|
debug_mode (`bool`, optional, *optional*, defaults to `False`):
|
|
Logits are modified to uniform one got before watermarking modification is applied. This is to test the
|
|
implementation.
|
|
|
|
Examples:
|
|
```python
|
|
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, SynthIDTextWatermarkingConfig
|
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b', padding_side="left")
|
|
>>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b')
|
|
|
|
>>> # SynthID Text configuration
|
|
>>> watermarking_config = SynthIDTextWatermarkingConfig(
|
|
... keys=[654, 400, 836, 123, 340, 443, 597, 160, 57],
|
|
... ngram_len=5,
|
|
... )
|
|
|
|
>>> # Generation with watermarking
|
|
>>> tokenized_prompts = tokenizer(["Once upon a time, "], return_tensors="pt", padding=True)
|
|
>>> output_sequences = model.generate(
|
|
... **tokenized_prompts, watermarking_config=watermarking_config, do_sample=True, max_new_tokens=10
|
|
... )
|
|
>>> watermarked_text = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
|
|
```
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
ngram_len: int,
|
|
keys: list[int],
|
|
context_history_size: int = 1024,
|
|
sampling_table_seed: int = 0,
|
|
sampling_table_size: int = 2**16,
|
|
skip_first_ngram_calls: bool = False,
|
|
debug_mode: bool = False,
|
|
):
|
|
self.ngram_len = ngram_len
|
|
self.keys = keys
|
|
self.sampling_table_size = sampling_table_size
|
|
self.sampling_table_seed = sampling_table_seed
|
|
self.context_history_size = context_history_size
|
|
self.skip_first_ngram_calls = skip_first_ngram_calls
|
|
self.debug_mode = debug_mode
|
|
|
|
def validate(self):
|
|
watermark_missing_arg_msg = (
|
|
"Some of the keys in `watermarking_config` are defined incorrectly. `{key}` should be {correct_value}` "
|
|
"but found {found_value}"
|
|
)
|
|
if self.sampling_table_size > 2**24:
|
|
raise ValueError(
|
|
watermark_missing_arg_msg.format(
|
|
key="sampling_table_size",
|
|
correct_value="< 2**24",
|
|
found_value=self.sampling_table_size,
|
|
),
|
|
)
|
|
|
|
def construct_processor(self, vocab_size: int, device) -> "WatermarkLogitsProcessor":
|
|
return SynthIDTextWatermarkLogitsProcessor(
|
|
ngram_len=self.ngram_len,
|
|
keys=self.keys,
|
|
sampling_table_size=self.sampling_table_size,
|
|
sampling_table_seed=self.sampling_table_seed,
|
|
context_history_size=self.context_history_size,
|
|
device=device,
|
|
skip_first_ngram_calls=self.skip_first_ngram_calls,
|
|
debug_mode=self.debug_mode,
|
|
)
|
|
|
|
|
|
@dataclass
|
|
class CompileConfig:
|
|
"""
|
|
Class that holds arguments relative to `torch.compile` behavior, when using automatic compilation in `generate`.
|
|
See [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) for more details on the arguments.
|
|
|
|
Args:
|
|
fullgraph (`bool`, *optional*, defaults to `True`):
|
|
If `True`, requires that the whole forward be capturable in a single graph.
|
|
dynamic (`bool` or `None`, *optional*):
|
|
Whether to try to use dynamic shape graphs.
|
|
backend (`str` or `Callable`, *optional*, defaults to `"inductor"`):
|
|
Backend to be used.
|
|
mode (`str`, *optional*, defaults to `"reduce-overhead"`):
|
|
Controls balance between performance and overhead.
|
|
options (`dict`, *optional*):
|
|
A dictionary of options to pass to the backend.
|
|
|
|
Examples:
|
|
```python
|
|
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig
|
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b')
|
|
>>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b').cuda()
|
|
|
|
>>> # Automatic compile configuration, used with static cache
|
|
>>> compile_config = CompileConfig(dynamic=True)
|
|
|
|
>>> # Generation with static cache and compile config
|
|
>>> input = tokenizer.encode("Hello there, how", return_tensors="pt").cuda()
|
|
>>> output = model.generate(
|
|
... input, do_sample=False, max_new_tokens=300, cache_implementation="static", compile_config=compile_config
|
|
... )
|
|
>>> output_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
|
|
```
|
|
"""
|
|
|
|
fullgraph: bool = True
|
|
dynamic: Optional[bool] = None
|
|
backend: Union[str, Callable] = "inductor"
|
|
mode: str = "reduce-overhead"
|
|
options: Optional[dict] = None
|
|
# Used to flag our `generate` call to compile on e.g. CPU. Often not optimal, but useful for testing purposes.
|
|
_compile_all_devices = None
|
|
|
|
def to_dict(self) -> dict[str, Any]:
|
|
"""Serializes this instance to a Python dictionary."""
|
|
return copy.deepcopy({key: value for key, value in self.__dict__.items() if key != "_compile_all_devices"})
|