354 lines
16 KiB
Python
354 lines
16 KiB
Python
|
# coding=utf-8
|
||
|
# Copyright 2024 Cohere team. All rights reserved.
|
||
|
#
|
||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||
|
# and OPT implementations in this library. It has been modified from its
|
||
|
# original forms to accommodate minor architectural differences compared
|
||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
# This file is based on the LLama model definition file in transformers
|
||
|
|
||
|
"""PyTorch Cohere model."""
|
||
|
|
||
|
from typing import Callable, Optional, Union
|
||
|
|
||
|
import torch
|
||
|
import torch.utils.checkpoint
|
||
|
from torch import nn
|
||
|
|
||
|
from ...cache_utils import Cache
|
||
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||
|
from ...modeling_rope_utils import dynamic_rope_update
|
||
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||
|
from ...processing_utils import Unpack
|
||
|
from ...utils import TransformersKwargs, logging
|
||
|
from ..llama.modeling_llama import (
|
||
|
LlamaAttention,
|
||
|
LlamaForCausalLM,
|
||
|
LlamaMLP,
|
||
|
LlamaModel,
|
||
|
LlamaRotaryEmbedding,
|
||
|
eager_attention_forward,
|
||
|
)
|
||
|
from .configuration_cohere import CohereConfig
|
||
|
|
||
|
|
||
|
# Module-level logger from the library's logging utilities (not referenced elsewhere in this file).
logger = logging.get_logger(__name__)
|
||
|
|
||
|
|
||
|
class CohereLayerNorm(nn.Module):
    """Bias-free layer normalization over the last axis, computed in float32.

    The learnable ``weight`` may be multi-dimensional. With a tuple such as
    ``(num_heads, head_dim)``, each head gets its own scale vector. The mean and
    variance are still taken over the final axis only, which is how QK-norm
    normalizes across ``head_dim``.
    """

    def __init__(self, hidden_size=None, eps=1e-5, bias=False):
        """Create the normalization layer.

        Args:
            hidden_size (`int` or `tuple[int, ...]`): shape of the learnable scale ``weight``.
                The tuple is used for QKNorm to normalize across head_dim.
            eps (`float`): constant added to the variance before the reciprocal square root.
            bias (`bool`): accepted for signature compatibility but ignored; no bias
                parameter is created.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize ``hidden_states`` along its last axis and multiply by ``weight``.

        The arithmetic runs in float32, and the result is cast back to the dtype
        of ``hidden_states``.
        """
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        # Center once and reuse the centered values for both the variance and the output.
        centered = as_fp32 - as_fp32.mean(-1, keepdim=True)
        var = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered * torch.rsqrt(var + self.variance_epsilon)
        scaled = self.weight.to(torch.float32) * normalized
        return scaled.to(original_dtype)
|
||
|
|
||
|
|
||
|
class CohereRotaryEmbedding(LlamaRotaryEmbedding):
    """Rotary embedding that repeats each frequency for an adjacent channel pair.

    The only difference from the parent is how the frequency table is widened.
    Each frequency is repeated in place (``[f0, f0, f1, f1, ...]``) instead of
    the table being concatenated with itself. This matches the interleaved
    ``rotate_half`` used by Cohere.
    """

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids``, cast to ``x.dtype``.

        Within this method, ``x`` is only read for its device type and dtype.
        """
        batch_size = position_ids.shape[0]
        freq_column = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1)
        position_row = position_ids[:, None, :].float()

        # Autocast is disabled so the angles are computed in float32.
        # "mps" (and any non-string device type) falls back to the "cpu" autocast device type.
        autocast_device = x.device.type
        if not isinstance(autocast_device, str) or autocast_device == "mps":
            autocast_device = "cpu"

        with torch.autocast(device_type=autocast_device, enabled=False):  # Force float32
            angles = torch.matmul(freq_column.float(), position_row.float()).transpose(1, 2)
            # diff from Llama: repeat each frequency in place instead of cat()-ing the table
            emb = angles.repeat_interleave(2, dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||
|
|
||
|
|
||
|
def rotate_half(x):
    """Rotate each adjacent channel pair of ``x`` by 90 degrees: ``(a, b) -> (-b, a)``.

    Channels are paired interleaved, as ``(0, 1), (2, 3), ...``. This differs
    from e.g. Llama, which pairs the first half of the last axis with the
    second half.
    """
    even_channels = x[..., 0::2]
    odd_channels = x[..., 1::2]
    pairs = torch.stack((odd_channels.neg(), even_channels), dim=-1)
    return pairs.flatten(start_dim=-2)
|
||
|
|
||
|
|
||
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    The rotation is computed in float32. Both results are cast back to the dtype
    of ``q``; note that ``k``'s result also uses ``q``'s dtype.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis at which a singleton dimension is inserted into `cos` and `sin` so they broadcast
            against `q` and `k`.
            - If `cos`/`sin` are `[batch_size, seq_len, head_dim]` and `q`/`k` are
              `[batch_size, heads, seq_len, head_dim]`, use `unsqueeze_dim=1`.
            - If `q`/`k` are `[batch_size, seq_len, heads, head_dim]`, use `unsqueeze_dim=2`.

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    target_dtype = q.dtype
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # Upcast, then combine the direct and 90-degree-rotated components.
        t32 = t.float()
        return (t32 * cos_b) + (rotate_half(t32) * sin_b)

    return _rotate(q).to(dtype=target_dtype), _rotate(k).to(dtype=target_dtype)
|
||
|
|
||
|
|
||
|
class CohereMLP(LlamaMLP):
    """Subclass of ``LlamaMLP`` whose three projections are rebuilt without bias terms.

    ``forward`` is inherited unchanged from the parent.
    """

    def __init__(self, config):
        super().__init__(config)
        # Replace the parent's projections with explicitly bias-free linear layers.
        model_dim, inner_dim = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, model_dim, bias=False)
|
||
|
|
||
|
|
||
|
class CohereAttention(LlamaAttention):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Builds on ``LlamaAttention``. The projections, ``head_dim``, ``scaling``,
    ``attention_dropout`` and ``layer_idx`` all come from the parent. This class
    adds an optional LayerNorm over ``head_dim`` for queries and keys, enabled by
    ``config.use_qk_norm`` and applied before the rotary embedding.
    """

    def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        self.use_qk_norm = config.use_qk_norm
        if self.use_qk_norm:
            # One scale vector per head; statistics are taken over head_dim only.
            # When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads
            self.q_norm = CohereLayerNorm(
                hidden_size=(config.num_attention_heads, self.head_dim), eps=config.layer_norm_eps
            )
            self.k_norm = CohereLayerNorm(
                hidden_size=(config.num_key_value_heads, self.head_dim), eps=config.layer_norm_eps
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states (`torch.Tensor`): layer input. Its last axis is split into heads of size
                ``head_dim``. Presumably ``(batch, seq_len, hidden_size)``; the ``transpose(1, 2)``
                below moves the head axis in front of dim 1.
            position_embeddings (`tuple[torch.Tensor, torch.Tensor]`): ``(cos, sin)``. Applied to the
                queries and keys via `apply_rotary_pos_emb` and also passed to the cache update.
            attention_mask (`torch.Tensor`, *optional*): forwarded unchanged to the attention function.
            past_key_value (`Cache`, *optional*): if given, the new key/value states are passed to
                ``past_key_value.update``, and the states it returns are used for attention.
            cache_position (`torch.LongTensor`, *optional*): forwarded to the cache update.
            **kwargs: extra arguments forwarded to the attention function.

        Returns:
            `tuple`: ``(attn_output, attn_weights)``. ``attn_output`` is the attention output
            reshaped to ``(*input_shape, -1)`` and projected by ``o_proj``. ``attn_weights`` is
            the second value returned by the attention function.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape)
        key_states = self.k_proj(hidden_states).view(hidden_shape)
        value_states = self.v_proj(hidden_states).view(hidden_shape)

        if self.use_qk_norm:  # main diff from Llama
            # Normalize while the head axis is still second-to-last, so the per-head
            # (num_heads, head_dim) weight broadcasts correctly.
            query_states = self.q_norm(query_states)
            key_states = self.k_norm(key_states)

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position is forwarded for caches
            # that index by position (e.g. the static cache)
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
|
||
|
|
||
|
|
||
|
class CohereDecoderLayer(GradientCheckpointingLayer):
    """Cohere decoder block with a parallel residual.

    A single LayerNorm feeds both the self-attention and the MLP. Their outputs
    are summed with the residual: ``x + attn(ln(x)) + mlp(ln(x))``.
    """

    def __init__(self, config: CohereConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = CohereAttention(config=config, layer_idx=layer_idx)
        self.mlp = CohereMLP(config)
        # Only one norm: attention and MLP consume the same normalized input.
        self.input_layernorm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer, expected shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.Tensor`, *optional*):
                Forwarded unchanged to the self-attention module. Its expected shape depends on the
                attention implementation in use.
            position_ids (`torch.LongTensor`, *optional*):
                Forwarded to the self-attention module as an extra keyword argument.
            past_key_value (`Cache`, *optional*): cache of past key and value projection states, updated in
                place by the self-attention module.
            use_cache (`bool`, *optional*):
                Forwarded to the self-attention module as an extra keyword argument.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                `(cos, sin)` rotary embeddings, expected shape `(batch_size, seq_len, head_dim)`. Although
                optional here, `CohereAttention.forward` unpacks them unconditionally, so they must be provided.

        Returns:
            `torch.Tensor`: `residual + attention_output + mlp_output`. No attention weights are returned.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self attention; the attention weights are discarded.
        hidden_states_attention, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )

        # MLP runs in parallel on the same normalized input (not on the attention output).
        hidden_states_mlp = self.mlp(hidden_states)

        hidden_states = residual + hidden_states_attention + hidden_states_mlp
        return hidden_states
|
||
|
|
||
|
|
||
|
class CohereModel(LlamaModel):
    """``LlamaModel`` with Cohere decoder layers, the interleaved rotary embedding and a LayerNorm final norm."""

    def __init__(self, config: CohereConfig):
        super().__init__(config)
        # Override the components presumably created by the parent constructor,
        # keeping the original assignment order: layers, rotary embedding, final norm.
        cohere_layers = [CohereDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(cohere_layers)
        self.rotary_emb = CohereRotaryEmbedding(config=config)
        self.norm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
|
||
|
|
||
|
|
||
|
class CohereForCausalLM(LlamaForCausalLM):
    """Causal LM on top of ``CohereModel``; the LM-head logits are multiplied by ``config.logit_scale``."""

    def __init__(self, config):
        super().__init__(config)
        # Replace the backbone set by the parent with the Cohere one.
        self.model = CohereModel(config)
        self.logit_scale = config.logit_scale
        self.tie_word_embeddings = config.tie_word_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        logits_to_keep (`int` or `torch.Tensor`, defaults to 0):
            An int keeps the logits of the last `logits_to_keep` positions (0 keeps all of them); a tensor is
            used directly as the index into the sequence axis.

        Example:

        ```python
        >> from transformers import AutoTokenizer, CohereForCausalLM

        >> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")
        >> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")

        >> prompt = "Hey, are you conscious? Can you talk to me?"
        >> inputs = tokenizer(prompt, return_tensors="pt")

        >> # Generate
        >> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # Fall back to the config defaults for the output flags.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        # Only project the positions that are needed (slice(-0, None) keeps everything).
        if isinstance(logits_to_keep, int):
            kept_positions = slice(-logits_to_keep, None)
        else:
            kept_positions = logits_to_keep
        logits = self.lm_head(outputs.last_hidden_state[:, kept_positions, :]) * self.logit_scale  # main diff from Llama

        loss = (
            self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
            if labels is not None
            else None
        )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||
|
|
||
|
|
||
|
# Public names exported by this module.
# NOTE(review): "CoherePreTrainedModel" is not defined anywhere in this file; the
# `noqa: F822` (undefined name in `__all__`) silences that lint. Presumably the class
# is supplied by generated code — confirm before relying on this export.
__all__ = [
    "CohereForCausalLM",
    "CohereModel",
    "CoherePreTrainedModel",  # noqa: F822
]
|