# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
from typing import TYPE_CHECKING, Any, Optional
|
|
|
|
import sentencepiece as spm
|
|
import torch
|
|
from torch import nn
|
|
|
|
from ...cache_utils import Cache, DynamicCache
|
|
from ...configuration_utils import PretrainedConfig
|
|
from ...masking_utils import create_causal_mask
|
|
from ...modeling_outputs import BaseModelOutputWithPast
|
|
from ...processing_utils import Unpack
|
|
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
|
|
from ...utils import TransformersKwargs, logging
|
|
from ..llama.modeling_llama import (
|
|
LlamaForCausalLM,
|
|
LlamaForSequenceClassification,
|
|
LlamaForTokenClassification,
|
|
LlamaMLP,
|
|
LlamaModel,
|
|
)
|
|
from ..llama.tokenization_llama import LlamaTokenizer
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from ...tokenization_utils_base import TextInput
|
|
|
|
# Associates the `vocab_file` argument name with the file name "tokenizer.model".
# NOTE(review): presumably consumed by the tokenizer base classes when loading/saving - confirm.
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}

# Marker character that `GemmaTokenizer._decode` replaces with a plain space in its output.
SPIECE_UNDERLINE = "▁"

# Module-level logger, named after this module.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
class GemmaConfig(PretrainedConfig):
    r"""
    Configuration class for a [`GemmaModel`]. It stores the hyper-parameters that define the model architecture and
    is used to instantiate a Gemma model. Instantiating it with the default values yields a configuration similar to
    that of Gemma-7B, e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 256000):
            Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`GemmaModel`].
        hidden_size (`int`, *optional*, defaults to 3072):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 24576):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 28):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 16):
            Number of key/value heads used to implement Grouped Query Attention. With
            `num_key_value_heads=num_attention_heads` the model uses Multi Head Attention (MHA), with
            `num_key_value_heads=1` it uses Multi Query Attention (MQA), otherwise GQA is used. When converting a
            multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by
            meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245).
        head_dim (`int`, *optional*, defaults to 256):
            The attention head dimension.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The legacy activation function. It is overwritten by `hidden_activation`.
        hidden_activation (`str` or `function`, *optional*):
            The non-linear activation function (function or string) in the decoder. Will default to
            `"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"`
            activation function.
        max_position_embeddings (`int`, *optional*, defaults to 8192):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 2):
            Beginning of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to tie weight embeddings.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    ```python
    >>> from transformers import GemmaModel, GemmaConfig
    >>> # Initializing a Gemma gemma-7b style configuration
    >>> configuration = GemmaConfig()
    >>> # Initializing a model from the gemma-7b style configuration
    >>> model = GemmaModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "gemma"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Per-module sharding style, keyed by module-name pattern.
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    # For each top-level sub-module: (names of its inputs, names of its outputs).
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=256000,
        hidden_size=3072,
        intermediate_size=24576,
        num_hidden_layers=28,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=256,
        hidden_act="gelu_pytorch_tanh",
        hidden_activation=None,
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        # Model dimensions.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.max_position_embeddings = max_position_embeddings

        # Attention layout.
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.rope_theta = rope_theta

        # Activation, normalization, initialization and caching.
        self.hidden_act = hidden_act
        self.hidden_activation = hidden_activation
        self.rms_norm_eps = rms_norm_eps
        self.initializer_range = initializer_range
        self.use_cache = use_cache

        # Token ids, weight tying and any remaining keyword arguments are handled by the base class.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
|
|
|
|
|
|
class GemmaTokenizer(LlamaTokenizer, PreTrainedTokenizer):
    """
    Construct a Gemma tokenizer. Based on byte-level Byte-Pair-Encoding. The vocabulary is loaded from a
    SentencePiece model file.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
            this token instead.
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<bos>"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
            token.
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<eos>"`):
            The end of sequence token.
        pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<pad>"`):
            A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored
            by attention mechanisms or loss computation.
        sp_model_kwargs (`dict[str, Any]`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other
            things, to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
                using forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.

        add_bos_token (`bool`, *optional*, defaults to `True`):
            Whether or not to add a `bos_token` at the start of sequences.
        add_eos_token (`bool`, *optional*, defaults to `False`):
            Whether or not to add an `eos_token` at the end of sequences.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
            Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
            extra spaces.
        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
            Whether or not the default system prompt for Gemma should be used.
        spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not to add spaces between special tokens.
    """

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        bos_token="<bos>",
        eos_token="<eos>",
        pad_token="<pad>",
        sp_model_kwargs: Optional[dict[str, Any]] = None,
        add_bos_token=True,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        use_default_system_prompt=False,
        spaces_between_special_tokens=False,
        **kwargs,
    ):
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        # Plain strings are wrapped as special, non-normalized tokens; AddedToken instances are used as given.
        bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
        unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
        pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token

        self.vocab_file = vocab_file
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.use_default_system_prompt = use_default_system_prompt
        # The SentencePiece model is loaded before the base initializer runs.
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)

        # `PreTrainedTokenizer.__init__` is called explicitly rather than through `super()`.
        PreTrainedTokenizer.__init__(
            self,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            sp_model_kwargs=sp_model_kwargs,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            use_default_system_prompt=use_default_system_prompt,
            spaces_between_special_tokens=spaces_between_special_tokens,
            **kwargs,
        )

    def get_spm_processor(self):
        raise AttributeError("Not needed for Gemma")

    def unk_token_length(self):
        raise AttributeError("Not needed for Gemma")

    def tokenize(self, text: "TextInput", **kwargs) -> list[str]:
        """
        Args:
            text: TextInput
        Simply calls PreTrainedTokenizer's method
        """
        return PreTrainedTokenizer.tokenize(self, text, **kwargs)

    def _tokenize(self, text, **kwargs):
        """
        Args:
            text: TextInput
        Returns a tokenized string. The Gemma tokenizer never adds a prefix space.
        """
        return self.sp_model.encode(text, out_type=str)

    def _decode(
        self,
        token_ids: list[int],
        skip_special_tokens: bool = False,
        spaces_between_special_tokens: bool = False,
        **kwargs,
    ) -> str:
        """
        Convert a list of token ids back into text.

        Runs of ordinary ids are decoded together by the SentencePiece model, while ids found in
        `self._added_tokens_decoder` are emitted through their stored `content`.

        Args:
            token_ids: ids to decode.
            skip_special_tokens: when `True`, ids listed in `self.all_special_ids` are dropped.
            spaces_between_special_tokens: when `True`, the decoded segments are joined with a single space
                instead of being concatenated directly.

        Returns:
            The decoded string, with every `SPIECE_UNDERLINE` replaced by a space.
        """
        # Read `all_special_ids` once rather than on every loop iteration; it is only needed when skipping.
        skipped_ids = self.all_special_ids if skip_special_tokens else []

        sub_texts = []
        current_sub_text = []
        for ids in token_ids:
            if ids in skipped_ids:
                continue
            if ids in self._added_tokens_decoder:
                # An added token ends the current run of ordinary ids: flush the run, then emit the token text.
                if current_sub_text:
                    sub_texts.append(self.sp_model.decode(current_sub_text))
                sub_texts.append(self._added_tokens_decoder[ids].content)
                current_sub_text = []
            else:
                current_sub_text.append(ids)
        if current_sub_text:
            sub_texts.append(self.sp_model.decode(current_sub_text))

        separator = " " if spaces_between_special_tokens else ""
        return separator.join(sub_texts).replace(SPIECE_UNDERLINE, " ")

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        pieces = []
        current_sub_tokens = []
        for token in tokens:
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self._added_tokens_encoder:
                pieces.append(self.sp_model.decode(current_sub_tokens))
                pieces.append(token)
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
        pieces.append(self.sp_model.decode(current_sub_tokens))
        return "".join(pieces)
|
|
|
|
|
|
class GemmaRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension, scaled by `1 + weight`."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # The weight starts at zero, so the effective scale `1 + weight` starts at one.
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """Divide `x` by the root of its mean square along the last dimension (`eps` added for stability)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalize and scale in float32, then cast the result back to the dtype of `x`."""
        normalized = self._norm(x.float())
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        scale = 1.0 + self.weight.float()
        return (normalized * scale).type_as(x)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
|
|
|
|
|
class GemmaMLP(LlamaMLP):
    # Gemma variant of the Llama MLP: the gate, up and down projections are re-created without bias.

    def __init__(self, config):
        # `config` must be forwarded: `self.hidden_size` and `self.intermediate_size` read below are never
        # assigned in this class, so they have to come from the parent initializer.
        # NOTE(review): assumes `LlamaMLP.__init__(config)` sets both attributes - confirm against modeling_llama.
        super().__init__(config)
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
|
|
|
|
|
class GemmaModel(LlamaModel):
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Exactly one input form is accepted: passing both, or neither, is an error.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Create a cache on demand when caching is requested and none was supplied.
        if past_key_values is None and use_cache:
            past_key_values = DynamicCache()

        # Default cache positions continue from the number of tokens already held by the cache.
        if cache_position is None:
            start = 0 if past_key_values is None else past_key_values.get_seq_length()
            cache_position = torch.arange(start, start + inputs_embeds.shape[1], device=inputs_embeds.device)

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds

        # Rotary position embeddings are computed once and shared by every decoder layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # Scale the embeddings by sqrt(hidden_size). The factor is built in the activations' dtype on purpose:
        # Gemma downcasts it to float16, causing sqrt(3072)=55.4256 to become 55.5
        # See https://github.com/huggingface/transformers/pull/29402
        embedding_scale = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * embedding_scale

        active_layers = self.layers[: self.config.num_hidden_layers]
        for layer in active_layers:
            hidden_states = layer(
                hidden_states,
                attention_mask=mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        # The cache is only handed back when caching was requested.
        returned_cache = past_key_values if use_cache else None
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=returned_cache,
        )
|
|
|
|
|
|
class GemmaForCausalLM(LlamaForCausalLM):
    # `self` is required here: without it a bound call fails with a TypeError, and the zero-argument
    # `super()` below has no instance to bind to.
    def forward(self, **super_kwargs):
        r"""
        Forwards every keyword argument unchanged to the parent class's `forward` and returns its result.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, GemmaForCausalLM

        >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""
        return super().forward(**super_kwargs)
|
|
|
|
|
|
# Gemma sequence-classification model: subclasses `LlamaForSequenceClassification` without overriding anything.
class GemmaForSequenceClassification(LlamaForSequenceClassification):
    pass
|
|
|
|
|
|
# Gemma token-classification model: subclasses `LlamaForTokenClassification` without overriding anything.
class GemmaForTokenClassification(LlamaForTokenClassification):
    pass
|
|
|
|
|
|
# Public names exported by this module.
# "GemmaPreTrainedModel" is listed although no class of that name is defined in this file, hence the
# `noqa: F822` (undefined name in `__all__`).
# NOTE(review): presumably the class is produced elsewhere (e.g. by code generation) - confirm.
__all__ = [
    "GemmaConfig",
    "GemmaTokenizer",
    "GemmaModel",
    "GemmaForCausalLM",
    "GemmaForSequenceClassification",
    "GemmaForTokenClassification",
    "GemmaPreTrainedModel",  # noqa: F822
]
|