466 lines
18 KiB
Python
466 lines
18 KiB
Python
# coding=utf-8
|
|
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
|
# and OPT implementations in this library. It has been modified from its
|
|
# original forms to accommodate minor architectural differences compared
|
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""PyTorch Mixtral model."""
|
|
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import torch.utils.checkpoint
|
|
from torch import nn
|
|
|
|
from ...activations import ACT2FN
|
|
from ...cache_utils import Cache, DynamicCache
|
|
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
|
from ...modeling_layers import GradientCheckpointingLayer
|
|
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
|
|
from ...processing_utils import Unpack
|
|
from ...utils import TransformersKwargs, logging
|
|
from ...utils.generic import OutputRecorder
|
|
from ..mistral.modeling_mistral import (
|
|
MistralAttention,
|
|
MistralForCausalLM,
|
|
MistralForQuestionAnswering,
|
|
MistralForSequenceClassification,
|
|
MistralForTokenClassification,
|
|
MistralModel,
|
|
MistralPreTrainedModel,
|
|
MistralRMSNorm,
|
|
MistralRotaryEmbedding,
|
|
)
|
|
from .configuration_mixtral import MixtralConfig
|
|
|
|
|
|
# Module-level logger named after this module.
logger = logging.get_logger(name=__name__)
|
|
|
|
|
|
def load_balancing_loss_func(
    gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
    num_experts: Optional[int] = None,
    top_k: int = 2,
    attention_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, int]:
    r"""
    Computes the auxiliary load-balancing loss from the Switch Transformer paper, implemented in PyTorch.

    See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the
    loss presented in equations (4) - (6) of the paper. It penalizes routing that is too unbalanced across experts.

    Args:
        gate_logits:
            Router logits, expected to be a tuple with one tensor per layer, each of shape
            `[batch_size * sequence_length, num_experts]`. Anything that is not a non-empty tuple (including `None`)
            yields a loss of `0`.
        num_experts (`int`, *optional*):
            Number of experts. When `None`, it is inferred from the last dimension of the gate logits.
        top_k (`int`, defaults to 2):
            The number of experts each token is routed to, i.e. the `top-k` routing parameter.
        attention_mask (`torch.Tensor`, *optional*):
            The attention mask used in the forward pass, of shape `[batch_size, sequence_length]`. When given,
            positions where the mask is 0 are excluded from both averages.

    Returns:
        The auxiliary loss as a scalar tensor, or the integer `0` when no usable gate logits are provided.
    """
    # `isinstance(None, tuple)` is False, so this also covers `gate_logits is None`.
    if not isinstance(gate_logits, tuple) or len(gate_logits) == 0:
        return 0

    compute_device = gate_logits[0].device
    # Concatenate all layers along the token axis: (num_layers * batch * seq, num_experts).
    concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)

    if num_experts is None:
        # `one_hot` below requires an int; derive it from the router's output width.
        num_experts = concatenated_gate_logits.shape[-1]

    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    # (tokens, top_k, num_experts)
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the fraction of tokens routed to each expert
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

        # Compute the average routing probability assigned to each expert
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)

        # Broadcast the mask to expert_mask's shape so padding tokens contribute 0
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )

        # Compute the fraction of (non-padding) tokens routed to each expert
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
            expert_attention_mask, dim=0
        )

        # Broadcast the mask to routing_weights' shape so padding tokens contribute 0
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )

        # Compute the average routing probability (over non-padding tokens) assigned to each expert
        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
            router_per_expert_attention_mask, dim=0
        )

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts
|
|
|
|
|
|
class MixtralBlockSparseTop2MLP(nn.Module):
    """A single gated feed-forward expert computing `w2(act_fn(w1(x)) * w3(x))`."""

    def __init__(self, config: MixtralConfig):
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size

        # Attribute names w1/w2/w3 become the parameter names in the module's state dict.
        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        """Apply the gated projection and map the result back to `hidden_dim`."""
        gated = self.act_fn(self.w1(hidden_states))
        projected_up = self.w3(hidden_states)
        return self.w2(gated * projected_up)
|
|
|
|
|
|
class MixtralSparseMoeBlock(nn.Module):
    """
    Sparse mixture-of-experts feed-forward block with top-k token routing.

    This implementation is strictly equivalent to standard MoE with full capacity (no dropped tokens). It's faster
    since it formulates MoE operations in terms of block-sparse operations to accommodate imbalanced assignments of
    tokens to experts, whereas standard MoE either (1) drops tokens at the cost of reduced performance or (2) sets
    the capacity factor to the number of experts and thus wastes computation and memory on padding.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok

        # gating: one logit per expert for every token
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

        self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])

        # Multiplicative input jitter, applied only in training mode (see `forward`)
        self.jitter_noise = config.router_jitter_noise

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Route every token to its `top_k` experts and sum their weighted outputs.

        Args:
            hidden_states (`torch.Tensor`): input of shape `(batch_size, sequence_length, hidden_dim)`. It is never
                modified in place.

        Returns:
            `tuple(torch.Tensor, torch.Tensor)`: `(final_hidden_states, router_logits)`. `final_hidden_states` has
            the input's shape; `router_logits` has shape `(batch_size * sequence_length, num_experts)`.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        if self.training and self.jitter_noise > 0:
            # Out-of-place: an in-place `*=` would silently modify the caller's tensor.
            noise = torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
            hidden_states = hidden_states * noise
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        # Softmax in float32, keep the top-k experts per token, then renormalize their weights to sum to 1.
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # expert_mask: (num_experts, top_k, batch * sequence_length); entry [e, k, t] is 1 when token t selected
        # expert e as its k-th choice.
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        # Only visit experts that received at least one token.
        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        for expert_idx in expert_hit:
            expert_layer = self.experts[expert_idx]
            # idx: which top-k slot selected this expert; top_x: which token (row of hidden_states)
            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
            # Gather this expert's tokens and scale its outputs by the matching routing weights.
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

            # `index_add_` takes a tensor of indices, so scatter-add back into the token rows via `top_x`.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits
|
|
|
|
|
|
class MixtralRMSNorm(MistralRMSNorm):
    """RMS normalization layer; reuses the Mistral implementation without changes."""
|
|
|
|
|
|
class MixtralAttention(MistralAttention):
    """Self-attention layer; reuses the Mistral implementation without changes."""
|
|
|
|
|
|
class MixtralDecoderLayer(GradientCheckpointingLayer):
    """
    One Mixtral decoder layer: RMS-normed self-attention, then an RMS-normed sparse MoE block, each added back onto
    its input through a residual connection.
    """

    def __init__(self, config: MixtralConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = MixtralAttention(config, layer_idx)

        self.block_sparse_moe = MixtralSparseMoeBlock(config)
        self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[tuple[torch.Tensor]] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        """
        Run the attention and MoE sub-blocks and return the updated hidden states.

        Only the hidden states are returned; the attention weights and the MoE router logits produced inside are
        dropped from the return value. All remaining keyword arguments are forwarded to the attention module.
        """
        # --- attention sub-block (pre-norm + residual) ---
        normed = self.input_layernorm(hidden_states)
        attn_out, _ = self.self_attn(
            hidden_states=normed,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # --- sparse MoE sub-block (pre-norm + residual) ---
        moe_out, _ = self.block_sparse_moe(self.post_attention_layernorm(hidden_states))
        return hidden_states + moe_out
|
|
|
|
|
|
class MixtralRotaryEmbedding(MistralRotaryEmbedding):
    """Rotary position embedding; reuses the Mistral implementation without changes."""
|
|
|
|
|
|
class MixtralPreTrainedModel(MistralPreTrainedModel):
    """Base class for Mixtral models, adding MoE-specific flags on top of the Mistral base class."""

    # Full-graph torch.compile is turned off: MoE routing relies on data-dependent `torch.where(condition)`.
    _can_compile_fullgraph = False

    # Output name -> module whose outputs are captured for it. For the MoE block, `index=1` refers to position 1
    # of its `(hidden_states, router_logits)` return tuple.
    _can_record_outputs = {
        "router_logits": OutputRecorder(MixtralSparseMoeBlock, index=1),
        "hidden_states": MixtralDecoderLayer,
        "attentions": MixtralAttention,
    }
|
|
|
|
|
|
class MixtralModel(MistralModel):
    """Mixtral decoder stack. The only difference from Mistral is that it returns `MoeModelOutputWithPast`."""

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeModelOutputWithPast:
        """
        Embed the inputs, build the causal mask and rotary embeddings, and run every decoder layer.

        Exactly one of `input_ids` / `inputs_embeds` must be given. When `use_cache` is truthy and no cache is
        passed, a fresh `DynamicCache` is created. Missing `cache_position` / `position_ids` are derived from the
        number of tokens already in the cache.
        """
        # Both given or both missing -> error.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            already_cached = 0 if past_key_values is None else past_key_values.get_seq_length()
            new_tokens = inputs_embeds.shape[1]
            cache_position = torch.arange(already_cached, already_cached + new_tokens, device=inputs_embeds.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        if self.config.sliding_window is not None:
            build_mask = create_sliding_window_causal_mask
        else:
            build_mask = create_causal_mask
        causal_mask = build_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds

        # Rotary embeddings are computed once and shared by all decoder layers.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        return MoeModelOutputWithPast(
            last_hidden_state=self.norm(hidden_states),
            past_key_values=past_key_values,
        )
|
|
|
|
|
|
class MixtralForCausalLM(MistralForCausalLM):
    """Mixtral decoder with a language-modeling head and an optional auxiliary router load-balancing loss."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = MixtralModel(config)
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_local_experts
        self.num_experts_per_tok = config.num_experts_per_tok

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        output_router_logits (`bool`, *optional*):
            Whether to compute the auxiliary load-balancing loss from the router logits. Defaults to
            `config.output_router_logits`. When `labels` are also given, the loss scaled by
            `config.router_aux_loss_coef` is added to the language-modeling loss.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MixtralForCausalLM

        >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
        >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        # The base model returns a `MoeModelOutputWithPast` (last hidden state, cache, optional recorded outputs).
        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_router_logits=output_router_logits,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            # `load_balancing_loss_func` returns the plain int 0 when no router logits were collected; an int has
            # no `.to`, so only a real tensor is folded into the LM loss.
            if labels is not None and isinstance(aux_loss, torch.Tensor):
                # make sure to reside in the same device
                loss = loss + self.router_aux_loss_coef * aux_loss.to(loss.device)

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
|
|
|
|
|
|
class MixtralForSequenceClassification(MistralForSequenceClassification):
    """Sequence-classification head; reuses the Mistral implementation without changes."""
|
|
|
|
|
|
class MixtralForTokenClassification(MistralForTokenClassification):
    """Token-classification head; reuses the Mistral implementation without changes."""
|
|
|
|
|
|
class MixtralForQuestionAnswering(MistralForQuestionAnswering):
    """Extractive question-answering head; reuses the Mistral implementation without changes."""
|
|
|
|
|
|
# Public API of this module (order preserved).
__all__ = [
    "MixtralForCausalLM", "MixtralForQuestionAnswering", "MixtralModel",
    "MixtralPreTrainedModel", "MixtralForSequenceClassification", "MixtralForTokenClassification",
]
|