298 lines
11 KiB
Python
298 lines
11 KiB
Python
from typing import Callable, Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from ...cache_utils import Cache, DynamicCache
|
|
from ...masking_utils import create_causal_mask
|
|
from ...modeling_layers import GradientCheckpointingLayer
|
|
from ...modeling_outputs import (
|
|
BaseModelOutputWithPast,
|
|
)
|
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
|
from ...processing_utils import Unpack
|
|
from ...utils import TransformersKwargs, logging
|
|
from ..clip.modeling_clip import CLIPMLP
|
|
from ..llama.modeling_llama import (
|
|
LlamaAttention,
|
|
LlamaForCausalLM,
|
|
LlamaForSequenceClassification,
|
|
LlamaForTokenClassification,
|
|
LlamaModel,
|
|
LlamaRotaryEmbedding,
|
|
apply_rotary_pos_emb,
|
|
eager_attention_forward, # copied from Llama
|
|
)
|
|
from .configuration_phi import PhiConfig
|
|
|
|
|
|
# Identifiers kept at module level; neither is referenced elsewhere in this file
# (presumably consumed by documentation tooling -- TODO confirm).
_CONFIG_FOR_DOC = "PhiConfig"
_CHECKPOINT_FOR_DOC = "microsoft/phi-1"

# Module logger from the library's logging utilities (used for the gradient-checkpointing warning below).
logger = logging.get_logger(__name__)
|
|
|
|
|
|
class PhiAttention(LlamaAttention):
    """Multi-head attention used by Phi.

    Compared with ``LlamaAttention`` this layer:
      * uses biased ``q_proj`` / ``k_proj`` / ``v_proj`` projections,
      * names its (biased) output projection ``dense`` and removes the inherited ``o_proj``,
      * applies rotary position embeddings only to the first ``rotary_ndims``
        (``int(head_dim * config.partial_rotary_factor)``) channels of every head; the remaining
        channels pass through unchanged,
      * optionally applies a LayerNorm to queries and keys (``config.qk_layernorm``).
    """

    def __init__(self, config: PhiConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
        self.dense = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True)
        # Phi's output projection is `dense`; drop the inherited `o_proj` so it owns no parameters.
        del self.o_proj

        # Number of leading channels per head that receive rotary embeddings.
        self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)

        self.qk_layernorm = config.qk_layernorm
        if self.qk_layernorm:
            # These norms act on the last axis of the per-head query/key states, which `forward` shapes
            # as `self.head_dim` (see the `view(..., -1, self.head_dim)` below). Sizing them from
            # `self.head_dim` keeps them consistent with the projections; it equals
            # `hidden_size // num_attention_heads` whenever the config does not override the head size.
            self.q_layernorm = nn.LayerNorm(self.head_dim, eps=config.layer_norm_eps, elementwise_affine=True)
            self.k_layernorm = nn.LayerNorm(self.head_dim, eps=config.layer_norm_eps, elementwise_affine=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute attention for one decoder layer.

        Args:
            hidden_states: tensor whose last axis is ``hidden_size``; the projections are viewed as
                ``(*input_shape, num_heads, head_dim)`` and axes 1 and 2 are swapped, so the input is
                presumably ``(batch, seq_len, hidden_size)`` -- TODO confirm for other ranks.
            position_embeddings: ``(cos, sin)`` pair handed to ``apply_rotary_pos_emb`` for the rotary slice.
            attention_mask: passed unchanged to the attention backend.
            past_key_value: optional cache; when given, the (rotated) keys and the values are pushed through
                ``past_key_value.update`` and the returned keys/values are used for attention.
            cache_position: forwarded to the cache inside ``cache_kwargs``.
            **kwargs: forwarded unchanged to the attention backend.

        Returns:
            ``(attn_output, attn_weights)``: the attention output projected by ``dense`` to ``hidden_size``
            features, and the weights returned by the selected attention backend.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project, split into heads and swap axes 1/2: for a (batch, seq_len, hidden) input this gives
        # (batch, num_heads, seq_len, head_dim).
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        if self.qk_layernorm:
            query_states = self.q_layernorm(query_states)
            key_states = self.k_layernorm(key_states)

        cos, sin = position_embeddings
        # Partial rotary embedding: split each head into a rotated prefix and an untouched suffix.
        query_rot, query_pass = (
            query_states[..., : self.rotary_ndims],
            query_states[..., self.rotary_ndims :],
        )
        key_rot, key_pass = (
            key_states[..., : self.rotary_ndims],
            key_states[..., self.rotary_ndims :],
        )
        # [batch_size, num_heads, seq_length, head_dim * config.partial_rotary_factor]
        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)

        # Re-attach the pass-through channels: [batch_size, num_heads, seq_length, head_dim]
        query_states = torch.cat((query_rot, query_pass), dim=-1)
        key_states = torch.cat((key_rot, key_pass), dim=-1)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge the per-head features back into one trailing axis before the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.dense(attn_output)
        return attn_output, attn_weights
|
|
|
|
|
|
class PhiMLP(CLIPMLP):
    """Phi feed-forward block: inherits ``CLIPMLP`` without any changes."""
|
|
|
|
|
|
class PhiDecoderLayer(GradientCheckpointingLayer):
    """One Phi decoder layer with a parallel attention / MLP residual.

    A single ``input_layernorm`` output feeds both the self-attention block and the MLP; both branch
    outputs go through ``resid_dropout`` and are added to the un-normalized input::

        out = dropout(attn(ln(x))) + dropout(mlp(ln(x))) + x
    """

    def __init__(self, config: PhiConfig, layer_idx: int):
        super().__init__()
        self.self_attn = PhiAttention(config, layer_idx=layer_idx)
        self.mlp = PhiMLP(config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # One dropout module, applied to the attention branch and to the MLP branch.
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs,
    ) -> tuple[torch.Tensor, ...]:
        """Apply the layer.

        Args:
            hidden_states: layer input; also used as the residual.
            attention_mask: passed to ``self_attn``.
            position_ids: passed to ``self_attn`` (which forwards extra keyword arguments to its
                attention backend).
            past_key_value: cache object handed to ``self_attn``; annotated ``Cache`` because that is what
                ``PhiAttention.forward`` expects (previously mis-annotated as a tuple of tensors).
            output_attentions: when truthy, the attention weights are appended to the returned tuple; also
                passed through to ``self_attn``.
            use_cache: passed through to ``self_attn``.
            cache_position: passed to ``self_attn``.
            position_embeddings: ``(cos, sin)`` pair required by ``self_attn``.
            **kwargs: passed through to ``self_attn``.

        Returns:
            ``(hidden_states,)`` or, when ``output_attentions`` is truthy,
            ``(hidden_states, self_attn_weights)``.
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention (reads the normalized states)
        attn_outputs, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        attn_outputs = self.resid_dropout(attn_outputs)

        # The MLP reads the SAME normalized states as attention (parallel block, not sequential).
        feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
        hidden_states = attn_outputs + feed_forward_hidden_states + residual
        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs
|
|
|
|
|
|
class PhiRotaryEmbedding(LlamaRotaryEmbedding):
    """Rotary position embedding for Phi; behavior is inherited unchanged from ``LlamaRotaryEmbedding``."""
|
|
|
|
|
|
class PhiModel(LlamaModel):
    """Phi decoder stack.

    Built on ``LlamaModel`` with three Phi-specific changes visible here: the layers are
    ``PhiDecoderLayer`` instances, dropout is applied to the input embeddings, and the inherited
    ``norm`` is replaced by a ``LayerNorm`` named ``final_layernorm``.
    """

    def __init__(self, config: PhiConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [PhiDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)]
        )
        self.embed_dropout = nn.Dropout(config.embd_pdrop)
        self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # `final_layernorm` takes over the role of the inherited `norm`, so the latter is removed.
        del self.norm

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        """Run the decoder stack.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be supplied, otherwise ``ValueError`` is
        raised. Flags left as ``None`` fall back to ``self.config``. Caching is switched off (with a
        warning) while gradient checkpointing in training mode. When caching is on and no cache is given a
        fresh ``DynamicCache`` is created. ``cache_position`` defaults to the positions following the
        cache's current length and ``position_ids`` default to ``cache_position`` with a leading axis of 1.

        Returns:
            ``BaseModelOutputWithPast`` holding the ``final_layernorm`` output, the cache (only when
            ``use_cache``), the input of every layer plus the final normalized output (when
            ``output_hidden_states``), and each layer's attention weights (when ``output_attentions``).
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if use_cache is None:
            use_cache = cfg.use_cache

        got_ids = input_ids is not None
        got_embeds = inputs_embeds is not None
        # Both given, or neither given, is an error.
        if got_ids == got_embeds:
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if not got_embeds:
            inputs_embeds = self.embed_tokens(input_ids)

        if past_key_values is None and use_cache:
            past_key_values = DynamicCache()

        if cache_position is None:
            already_cached = 0 if past_key_values is None else past_key_values.get_seq_length()
            new_tokens = inputs_embeds.shape[1]
            cache_position = torch.arange(
                already_cached, already_cached + new_tokens, device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=cfg,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        # diff with Llama: dropout on the embeddings (applied after the mask has been built).
        hidden_states = self.embed_dropout(inputs_embeds)

        # One (cos, sin) pair shared by every decoder layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        layer_inputs = [] if output_hidden_states else None
        layer_attns = [] if output_attentions else None

        for layer in self.layers[: cfg.num_hidden_layers]:
            if output_hidden_states:
                layer_inputs.append(hidden_states)

            layer_result = layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            hidden_states = layer_result[0]

            if output_attentions:
                layer_attns.append(layer_result[1])

        # diff with Llama: LayerNorm named `final_layernorm` instead of `norm`.
        hidden_states = self.final_layernorm(hidden_states)

        # The final normalized output is reported as the last hidden state as well.
        if output_hidden_states:
            layer_inputs.append(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=None if layer_inputs is None else tuple(layer_inputs),
            attentions=None if layer_attns is None else tuple(layer_attns),
        )
|
|
|
|
|
|
class PhiForCausalLM(LlamaForCausalLM):
    """Causal language model; identical to ``LlamaForCausalLM`` except that ``lm_head`` has a bias."""

    def __init__(self, config):
        super().__init__(config)
        # Replace the inherited head with a biased hidden_size -> vocab_size projection.
        self.lm_head = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.vocab_size,
            bias=True,
        )
|
|
|
|
|
|
class PhiForSequenceClassification(LlamaForSequenceClassification):
    """Sequence-classification head for Phi; inherits ``LlamaForSequenceClassification`` unchanged."""
|
|
|
|
|
|
class PhiForTokenClassification(LlamaForTokenClassification):
    """Token-classification head for Phi; inherits ``LlamaForTokenClassification`` unchanged."""
|
|
|
|
|
|
# Public API of this module. `PhiPreTrainedModel` is not defined in this file, which is why its entry
# carries an F822 ("undefined name in __all__") suppression.
__all__ = [
    "PhiPreTrainedModel",  # noqa: F822
    "PhiModel",
    "PhiForCausalLM",
    "PhiForSequenceClassification",
    "PhiForTokenClassification",
]
|