# coding=utf-8
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, TypedDict

import torch
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache
from ...processing_utils import Unpack
from ...utils import logging
from ..granitemoe.modeling_granitemoe import (
    GraniteMoeDecoderLayer,
    GraniteMoeForCausalLM,
    GraniteMoeModel,
    GraniteMoePreTrainedModel,
)
from .configuration_granitemoeshared import GraniteMoeSharedConfig


logger = logging.get_logger(__name__)


class GraniteFlashAttentionKwargs(TypedDict, total=False):
    """
    Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
    Use cases include padding-free training and fewer `torch.compile` graph breaks.

    Attributes:
        cu_seq_lens_q (`torch.LongTensor`):
            Cumulative sequence lengths for the query states.
        cu_seq_lens_k (`torch.LongTensor`):
            Cumulative sequence lengths for the key states.
        max_length_q (`int`):
            Maximum sequence length for the query states.
        max_length_k (`int`):
            Maximum sequence length for the key states.
        seq_idx (`torch.IntTensor`):
            Index of each packed sequence.
    """

    cu_seq_lens_q: torch.LongTensor
    cu_seq_lens_k: torch.LongTensor
    max_length_q: int
    max_length_k: int
    seq_idx: torch.IntTensor
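
# Illustrative sketch only (not used by the modeling code): one way the kwargs described by
# `GraniteFlashAttentionKwargs` could be assembled for a padding-free batch of three packed
# sequences with lengths 3, 5 and 2. The collator producing these tensors is an assumption;
# only the tensor layout matches the TypedDict above.
#
#     lengths = [3, 5, 2]
#     flash_kwargs = GraniteFlashAttentionKwargs(
#         cu_seq_lens_q=torch.tensor([0, 3, 8, 10], dtype=torch.long),  # cumulative lengths with a leading 0
#         cu_seq_lens_k=torch.tensor([0, 3, 8, 10], dtype=torch.long),
#         max_length_q=5,
#         max_length_k=5,
#         seq_idx=torch.tensor([[0, 0, 0, 1, 1, 1, 1, 1, 2, 2]], dtype=torch.int32),
#     )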


class GraniteMoeSharedMLP(nn.Module):
    """
    MLP layer for the shared experts.

    Args:
        config:
            Configuration object with model hyperparameters.
    """

    def __init__(self, config: GraniteMoeSharedConfig):
        super().__init__()

        self.input_size = config.hidden_size
        self.hidden_size = config.shared_intermediate_size
        self.activation = ACT2FN[config.hidden_act]
        self.input_linear = nn.Linear(self.input_size, self.hidden_size * 2, bias=False)
        self.output_linear = nn.Linear(self.hidden_size, self.input_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.input_linear(hidden_states)
        chunked_hidden_states = hidden_states.chunk(2, dim=-1)
        # GLU-style gating: the activated first half modulates the second half.
        hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1]
        hidden_states = self.output_linear(hidden_states)
        return hidden_states
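
# Shape sketch (illustrative only): with `hidden_size=H` and `shared_intermediate_size=S`,
# `input_linear` maps (batch, seq_len, H) -> (batch, seq_len, 2 * S); the two chunks of size S
# are combined as `activation(chunk_0) * chunk_1` and projected back to H by `output_linear`,
# so the output shape equals the input shape:
#
#     mlp = GraniteMoeSharedMLP(config)              # `config` is a GraniteMoeSharedConfig
#     x = torch.randn(2, 4, config.hidden_size)      # hypothetical batch
#     assert mlp(x).shape == x.shape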


class GraniteMoeSharedDecoderLayer(GraniteMoeDecoderLayer):
    def __init__(self, config: GraniteMoeSharedConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.shared_mlp = None if config.shared_intermediate_size == 0 else GraniteMoeSharedMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        output_router_logits: Optional[bool] = False,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[GraniteFlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                Attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
            output_router_logits (`bool`, *optional*):
                Whether or not to return the logits of all the routers. They are useful for computing the router loss
                and should not be returned during inference.
            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
            kwargs (`dict`, *optional*):
                Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for padding-free training
                and/or to improve `torch.compile` performance.
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )

        # Granite scales the branch output by `residual_multiplier` before adding it back.
        hidden_states = residual + hidden_states * self.residual_multiplier

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)

        if self.shared_mlp is None:
            hidden_states = moe_hidden_states
        else:
            # The dense shared MLP runs on every token and is summed with the routed MoE output.
            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)

        del moe_hidden_states

        hidden_states = residual + hidden_states * self.residual_multiplier

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if output_router_logits:
            outputs += (router_logits,)

        return outputs
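
# Data-flow summary of the decoder layer above (comments only, mirroring the code):
#
#     h = x + residual_multiplier * SelfAttn(input_layernorm(x))
#     y = h + residual_multiplier * (MoE(post_attention_layernorm(h)) + SharedMLP(post_attention_layernorm(h)))
#
# The SharedMLP term is skipped when `config.shared_intermediate_size == 0` (i.e. `self.shared_mlp is None`).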


class GraniteMoeSharedPreTrainedModel(GraniteMoePreTrainedModel):
    config: GraniteMoeSharedConfig
    _no_split_modules = ["GraniteMoeSharedDecoderLayer"]


class GraniteMoeSharedModel(GraniteMoeModel):
    def __init__(self, config: GraniteMoeSharedConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [GraniteMoeSharedDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )


class GraniteMoeSharedForCausalLM(GraniteMoeForCausalLM):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: GraniteMoeSharedConfig):
        super().__init__(config)
        self.model = GraniteMoeSharedModel(config)
        # Initialize weights and apply final processing
        self.post_init()


__all__ = ["GraniteMoeSharedForCausalLM", "GraniteMoeSharedModel", "GraniteMoeSharedPreTrainedModel"]
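
# Usage sketch (illustrative; the checkpoint id is a placeholder, not a verified model id):
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("<granitemoeshared-checkpoint>")
#     model = AutoModelForCausalLM.from_pretrained("<granitemoeshared-checkpoint>")
#     inputs = tokenizer("Hello", return_tensors="pt")
#     generated = model.generate(**inputs, max_new_tokens=20)
#     print(tokenizer.decode(generated[0], skip_special_tokens=True))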