237 lines
9.5 KiB
Python
237 lines
9.5 KiB
Python
# coding=utf-8
|
|
# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
|
# and OPT implementations in this library. It has been modified from its
|
|
# original forms to accommodate minor architectural differences compared
|
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""PyTorch Starcoder2 model."""
|
|
|
|
from typing import Callable, Optional, Union
|
|
|
|
import torch
|
|
import torch.utils.checkpoint
|
|
from torch import nn
|
|
|
|
from transformers.utils.generic import check_model_inputs
|
|
|
|
from ...activations import ACT2FN
|
|
from ...cache_utils import Cache, DynamicCache
|
|
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
|
from ...modeling_outputs import BaseModelOutputWithPast
|
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
|
from ...processing_utils import Unpack
|
|
from ...utils import TransformersKwargs, logging
|
|
from ..mistral.modeling_mistral import (
|
|
MistralAttention,
|
|
MistralDecoderLayer,
|
|
MistralForCausalLM,
|
|
MistralForSequenceClassification,
|
|
MistralForTokenClassification,
|
|
MistralModel,
|
|
MistralRotaryEmbedding,
|
|
apply_rotary_pos_emb,
|
|
eager_attention_forward,
|
|
)
|
|
from .configuration_starcoder2 import Starcoder2Config
|
|
|
|
|
|
# Module-level logger from the package's `utils.logging` helpers, named after this module.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
class Starcoder2MLP(nn.Module):
    # Starcoder2 feed-forward block: c_fc -> activation -> c_proj -> dropout.
    # Both projections carry a bias term only when `config.use_bias` is set.

    def __init__(self, config: Starcoder2Config):
        super().__init__()
        embed_dim = config.hidden_size
        # Expand hidden_size -> intermediate_size, then project back to hidden_size.
        self.c_fc = nn.Linear(embed_dim, config.intermediate_size, bias=config.use_bias)
        self.c_proj = nn.Linear(config.intermediate_size, embed_dim, bias=config.use_bias)
        # Activation looked up by name in the ACT2FN registry.
        self.act = ACT2FN[config.hidden_act]
        # Dropout probability applied to the block output.
        self.residual_dropout = config.residual_dropout

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # `hidden_states` is fed directly into nn.Linear, so it must be a tensor whose last
        # dimension is hidden_size; the output has the same last dimension.
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        # `training=self.training` makes the dropout a no-op in eval mode.
        hidden_states = nn.functional.dropout(hidden_states, p=self.residual_dropout, training=self.training)
        return hidden_states
|
|
|
|
|
|
class Starcoder2Attention(MistralAttention):
    # Attention inherited from MistralAttention, with two Starcoder2-specific changes:
    #   * the q/k/v/o projections are rebuilt so they carry a bias when `config.use_bias` is set;
    #   * dropout with probability `config.residual_dropout` is applied after the output projection.
    # `head_dim`, `config`, `scaling`, `attention_dropout` and `layer_idx` are read here but never
    # assigned in this class; they are expected to be set by MistralAttention.__init__.

    def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None):
        # The parent constructor must receive the config (and layer index): `self.head_dim` is read
        # on the projection lines below and only exists once the parent has initialised itself.
        super().__init__(config, layer_idx)
        self.residual_dropout = config.residual_dropout
        # Re-create the projections so that their bias follows `config.use_bias`.
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Returns (attn_output, attn_weights), where attn_weights is whatever the selected
        # attention interface returns as its second value.
        input_shape = hidden_states.shape[:-1]
        # -1 lets the head count differ per projection (num_attention_heads for q,
        # num_key_value_heads for k/v).
        hidden_shape = (*input_shape, -1, self.head_dim)

        # NOTE(review): assumes hidden_states is (batch, seq_len, hidden_size); the transpose then
        # yields (batch, heads, seq_len, head_dim) — confirm against callers.
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Eager attention is the default; any other configured implementation is looked up by name.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=getattr(self.config, "sliding_window", None),  # diff with Llama
            **kwargs,
        )

        # Flatten everything after the leading input dims back into one hidden dim, then project.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        attn_output = nn.functional.dropout(
            attn_output, p=self.residual_dropout, training=self.training
        )  # diff with Llama

        return attn_output, attn_weights
|
|
|
|
|
|
class Starcoder2DecoderLayer(MistralDecoderLayer):
    # Mistral decoder layer with Starcoder2 submodules: Starcoder2Attention, Starcoder2MLP, and
    # LayerNorm (eps=config.norm_epsilon) for both the input and post-attention norms.

    def __init__(self, config: Starcoder2Config, layer_idx: int):
        # Forward the real config and layer index to the parent constructor.
        super().__init__(config, layer_idx)
        # These assignments override any same-named submodules the parent constructor created.
        self.self_attn = Starcoder2Attention(config=config, layer_idx=layer_idx)
        self.mlp = Starcoder2MLP(config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
|
|
|
|
|
|
# Rotary position embedding for Starcoder2; adds no behavior of its own on top of
# MistralRotaryEmbedding.
class Starcoder2RotaryEmbedding(MistralRotaryEmbedding):
    pass
|
|
|
|
|
|
class Starcoder2Model(MistralModel):
    # Mistral backbone rebuilt with Starcoder2 decoder layers, a LayerNorm as the final norm,
    # and dropout (config.embedding_dropout) applied to the input embeddings.

    def __init__(self, config: Starcoder2Config):
        super().__init__(config)
        layer_count = config.num_hidden_layers
        self.layers = nn.ModuleList([Starcoder2DecoderLayer(config, idx) for idx in range(layer_count)])
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
        self.embedding_dropout = config.embedding_dropout

    @check_model_inputs
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Exactly one of input_ids / inputs_embeds must be supplied.
        has_ids = input_ids is not None
        has_embeds = inputs_embeds is not None
        if has_ids == has_embeds:
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if not has_embeds:
            inputs_embeds = self.embed_tokens(input_ids)

        # Lazily create a dynamic cache when caching is requested but none was passed in.
        if past_key_values is None and use_cache:
            past_key_values = DynamicCache()

        # Default positions continue from however many tokens the cache has already seen.
        if cache_position is None:
            seen = 0 if past_key_values is None else past_key_values.get_seq_length()
            new_tokens = inputs_embeds.shape[1]
            cache_position = torch.arange(seen, seen + new_tokens, device=inputs_embeds.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Sliding-window masking is used only when the config defines a window.
        if self.config.sliding_window is None:
            build_mask = create_causal_mask
        else:
            build_mask = create_sliding_window_causal_mask
        causal_mask = build_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        # Embedding dropout is the main difference from Llama.
        hidden_states = nn.functional.dropout(inputs_embeds, p=self.embedding_dropout, training=self.training)

        # Position embeddings are computed once and shared by every decoder layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        # The cache is returned only when caching was requested.
        return BaseModelOutputWithPast(
            last_hidden_state=self.norm(hidden_states),
            past_key_values=past_key_values if use_cache else None,
        )
|
|
|
|
|
|
# Causal language-modeling model for Starcoder2; inherits all behavior from MistralForCausalLM.
class Starcoder2ForCausalLM(MistralForCausalLM):
    pass
|
|
|
|
|
|
# Sequence-classification model for Starcoder2; inherits all behavior from
# MistralForSequenceClassification.
class Starcoder2ForSequenceClassification(MistralForSequenceClassification):
    pass
|
|
|
|
|
|
# Token-classification model for Starcoder2; inherits all behavior from
# MistralForTokenClassification.
class Starcoder2ForTokenClassification(MistralForTokenClassification):
    pass
|
|
|
|
|
|
# Public names exported by this module.
# "Starcoder2PreTrainedModel" is not defined in this file; `noqa: F822` silences the
# undefined-name-in-__all__ lint for it. NOTE(review): presumably it is produced when this
# modular file is converted into the full modeling file — confirm.
__all__ = [
    "Starcoder2ForCausalLM",
    "Starcoder2Model",
    "Starcoder2PreTrainedModel",  # noqa: F822
    "Starcoder2ForSequenceClassification",
    "Starcoder2ForTokenClassification",
]
|