1242 lines
52 KiB
Python
1242 lines
52 KiB
Python
# coding=utf-8
|
|
# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
|
|
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""PyTorch ESM model."""
|
|
|
|
import math
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
import torch.utils.checkpoint
|
|
from torch import nn
|
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
|
|
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
|
|
from ...modeling_layers import GradientCheckpointingLayer
|
|
from ...modeling_outputs import (
|
|
BaseModelOutputWithCrossAttentions,
|
|
BaseModelOutputWithPoolingAndCrossAttentions,
|
|
MaskedLMOutput,
|
|
SequenceClassifierOutput,
|
|
TokenClassifierOutput,
|
|
)
|
|
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
|
from ...utils import auto_docstring, can_return_tuple, logging
|
|
from ...utils.deprecation import deprecate_kwarg
|
|
from .configuration_esm import EsmConfig
|
|
|
|
|
|
if is_flash_attn_available():
|
|
from ...modeling_flash_attention_utils import _flash_attention_forward
|
|
|
|
|
|
# Module-level logger; used below for one-time warnings (e.g. the flash-attention fallback notice).
logger = logging.get_logger(__name__)
|
|
|
|
|
|
def rotate_half(x):
    """
    Split the last dimension into two halves `(a, b)` and return `(-b, a)`.

    This is the 90-degree rotation used by rotary position embeddings. `torch.chunk` is used so that an odd-sized
    last dimension puts the extra element in the first half.
    """
    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((second_half.neg(), first_half), dim=-1)
|
|
|
|
|
|
def apply_rotary_pos_emb(x, cos, sin):
    """
    Rotate `x` by the rotary angles encoded in `cos` / `sin`.

    The tables are trimmed along their third axis to `x.shape[-2]` (the sequence axis of `x`), so tables built for a
    longer sequence can be reused. Returns `x * cos + rotate_half(x) * sin`.
    """
    seq_len = x.shape[-2]
    cos_trimmed = cos[:, :, :seq_len, :]
    sin_trimmed = sin[:, :, :seq_len, :]
    return (x * cos_trimmed) + (rotate_half(x) * sin_trimmed)
|
|
|
|
|
|
def gelu(x):
    """
    Exact (erf-based) GELU: `x * Phi(x)`, where `Phi` is the standard normal CDF.

    This is the gelu implementation from the original ESM repo. Using F.gelu yields subtly wrong results.
    """
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
|
|
|
|
|
def symmetrize(x):
    "Make layer symmetric in final two dimensions (x + x^T over the last two axes), used for contact prediction."
    transposed = x.transpose(-1, -2)
    return transposed + x
|
|
|
|
|
|
def average_product_correct(x):
    """
    Perform average product correction (APC), used for contact prediction.

    Subtracts `row_sum * col_sum / total_sum` (computed over the last two axes) from `x`.
    """
    row_totals = x.sum(-1, keepdims=True)
    col_totals = x.sum(-2, keepdims=True)
    grand_total = x.sum((-1, -2), keepdims=True)

    expected = row_totals * col_totals
    expected.div_(grand_total)  # divide in place to avoid allocating another full-size tensor
    return x - expected
|
|
|
|
|
|
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.

    The cos/sin tables are cached on the module and rebuilt only when the sequence length or the device changes.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Inverse frequencies 1 / 10000^(2i / dim), one per channel pair; registered as a (non-trainable) buffer.
        channel_pairs = torch.arange(0, dim, 2, dtype=torch.int64).float()
        self.register_buffer("inv_freq", 1.0 / (10000 ** (channel_pairs / dim)))

        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=2):
        """
        Return `(cos, sin)` tables of shape `(1, 1, seq_len, dim)` for the length of `x` along `seq_dimension`.

        The cache is rebuilt when the length differs from the cached one, or when `x` lives on a different device
        (possibly due to tracing for instance).
        """
        seq_len = x.shape[seq_dimension]

        needs_rebuild = seq_len != self._seq_len_cached or self._cos_cached.device != x.device
        if needs_rebuild:
            self._seq_len_cached = seq_len
            positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            angles = torch.outer(positions, self.inv_freq)
            table = torch.cat((angles, angles), dim=-1).to(x.device)

            self._cos_cached = table.cos()[None, None, :, :]
            self._sin_cached = table.sin()[None, None, :, :]

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Rotate `q` and `k`; the sequence length is read from the second-to-last axis of `k`."""
        cos, sin = self._update_cos_sin_tables(k, seq_dimension=-2)

        rotated_q = apply_rotary_pos_emb(q, cos, sin).to(dtype=q.dtype)
        rotated_k = apply_rotary_pos_emb(k, cos, sin).to(dtype=k.dtype)
        return rotated_q, rotated_k
|
|
|
|
|
|
class EsmContactPredictionHead(nn.Module):
    """Performs symmetrization, apc, and computes a logistic regression on the output features"""

    def __init__(
        self,
        in_features: int,
        bias=True,
        eos_idx: int = 2,
    ):
        """
        Args:
            in_features: number of stacked attention maps fed to the regression (layers * heads).
            bias: whether the logistic-regression layer has a bias.
            eos_idx: token id whose rows/columns are zeroed out of the attention maps.
        """
        super().__init__()
        self.in_features = in_features
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias)
        self.activation = nn.Sigmoid()

    def forward(self, tokens, attentions):
        """
        Args:
            tokens: token ids, presumably `(batch, seq)` — TODO confirm against callers.
            attentions: stacked attention maps, unpacked below as `(batch, layers, heads, seq, seq)`.

        Returns:
            Contact probabilities of shape `(batch, seq - 2, seq - 2)` (first and last positions removed).
        """
        # remove eos token attentions: zero every row/column whose token is EOS (handles padded batches where
        # EOS is not the last position), then drop the final row/column.
        eos_mask = tokens.ne(self.eos_idx).to(attentions)
        eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
        attentions = attentions * eos_mask[:, None, None, :, :]
        attentions = attentions[..., :-1, :-1]
        # remove cls token attentions
        attentions = attentions[..., 1:, 1:]
        batch_size, layers, heads, seqlen, _ = attentions.size()
        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)

        # features: batch x channels x tokens x tokens (symmetric)
        attentions = attentions.to(self.regression.weight.device)
        attentions = average_product_correct(symmetrize(attentions))
        attentions = attentions.permute(0, 2, 3, 1)
        # Attentions may be float32 while the regression head is float16/bfloat16 (or vice versa); nn.Linear
        # requires matching dtypes, so cast here — after APC, so the correction runs in the attentions' precision.
        attentions = attentions.to(self.regression.weight.dtype)
        return self.activation(self.regression(attentions).squeeze(3))
|
|
|
|
|
|
class EsmEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.

    Produces word embeddings, optionally rescaled by ESM's "token dropout", plus absolute position embeddings
    (when `position_embedding_type == "absolute"`), an optional LayerNorm, and finally zeroes positions whose
    `attention_mask` is 0.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        if config.emb_layer_norm_before:
            self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        else:
            self.layer_norm = None
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        # position_ids has shape (1, max_position_embeddings). It is registered with persistent=False, so it is NOT
        # saved in the state dict.
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

        self.padding_idx = config.pad_token_id
        if self.position_embedding_type == "absolute":
            self.position_embeddings = nn.Embedding(
                config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
            )
        self.token_dropout = config.token_dropout
        self.mask_token_id = config.mask_token_id

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        inputs_embeds=None,
    ):
        """
        Args:
            input_ids: token ids; required for token dropout (it needs to know which tokens are masks).
            attention_mask: 1 for real tokens, 0 for padding. Optional; when absent, every position counts towards
                the sequence length used by token dropout and no positions are zeroed.
            position_ids: explicit position ids; derived from `input_ids` / `inputs_embeds` when omitted.
            inputs_embeds: precomputed word embeddings, used instead of `input_ids` when given.
        """
        if position_ids is None:
            if input_ids is not None:
                # Create the position ids from the input token ids. Any padded tokens remain padded.
                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx)
            else:
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an
        # embedding_scale factor here.
        embeddings = inputs_embeds

        # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout
        # flag is False then it is handled in the same way as BERT/RoBERTa. If it is set to True, however,
        # masked tokens are treated as if they were selected for input dropout and zeroed out.
        # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by
        # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample).
        # This is analogous to the way that dropout layers scale down outputs during evaluation when not
        # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
        # Token dropout needs `input_ids` to locate mask tokens, so it is skipped when only embeddings are given.
        if self.token_dropout and input_ids is not None:
            is_mask_token = input_ids == self.mask_token_id
            embeddings = embeddings.masked_fill(is_mask_token.unsqueeze(-1), 0.0)
            mask_ratio_train = 0.15 * 0.8  # Hardcoded as the ratio used in all ESM model training runs
            if attention_mask is not None:
                src_lengths = attention_mask.sum(-1)
            else:
                # No padding information: every position is part of the sequence.
                src_lengths = input_ids.shape[1]
            mask_ratio_observed = is_mask_token.sum(-1).float() / src_lengths
            embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
                embeddings.dtype
            )

        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings

        if self.layer_norm is not None:
            embeddings = self.layer_norm(embeddings)
        if attention_mask is not None:
            embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
        # Matt: I think this line was copied incorrectly from BERT, disabling it for now.
        # embeddings = self.dropout(embeddings)
        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor whose leading dims are (batch, seq)

        Returns: torch.Tensor of shape (batch, seq) holding `padding_idx + 1 ... padding_idx + seq`
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape)
|
|
|
|
|
|
class EsmSelfAttention(nn.Module):
    """
    Eager multi-head attention used by ESM (self-attention, or cross-attention when `encoder_hidden_states` is given).

    Position handling depends on `position_embedding_type`: `"relative_key"` / `"relative_key_query"` add learned
    distance scores, `"rotary"` rotates queries and keys, and any other value (e.g. `"absolute"`) adds nothing here.
    """

    def __init__(self, config, position_embedding_type=None, layer_idx=None):
        super().__init__()
        self.config = config

        # hidden_size must split evenly across heads, unless the config defines `embedding_size`.
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        # An explicit argument overrides the config; the config falls back to "absolute".
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        self.rotary_embeddings = None
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            # One embedding per signed distance in [-(max_pos - 1), max_pos - 1].
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        elif self.position_embedding_type == "rotary":
            self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)

        self.is_decoder = config.is_decoder
        self.layer_idx = layer_idx

    @deprecate_kwarg("past_key_value", version="4.54.0")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor]:
        """
        Args:
            hidden_states: query-side input, split below as `(batch, seq, heads, head_dim)`.
            attention_mask: added to the raw scores before softmax (presumably an additive mask with large
                negative values at masked positions — verify against EsmModel.forward). Replaced by
                `encoder_attention_mask` for cross-attention.
            head_mask: multiplied into the attention probabilities after dropout.
            encoder_hidden_states: when given, keys/values are computed from it (cross-attention).
            past_key_value: accepted for backward compatibility only; not used.
            output_attentions: also return the attention probabilities.

        Returns:
            `(context,)` or `(context, attention_probs)`, where `context` is `(batch, seq, all_head_size)`; a trailing
            `None` is appended when `is_decoder` (no key/value cache is produced).
        """
        # Sequence length is left as -1 so the same shape works for encoder keys/values of a different length.
        hidden_shape = (hidden_states.shape[0], -1, self.num_attention_heads, self.attention_head_size)

        # (batch, heads, seq, head_dim)
        query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.key(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
            value_layer = self.value(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
            attention_mask = encoder_attention_mask
        else:
            key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
            value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

        # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
        # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
        # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original
        # ESM code and fix rotary embeddings.
        query_layer = query_layer * self.attention_head_size**-0.5

        if self.position_embedding_type == "rotary":
            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            # NOTE(review): the distance matrix is sized from hidden_states only, so it assumes the key length equals
            # the query length (i.e. not cross-attention with a different encoder length) — confirm.
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            # Shift signed distances into the non-negative index range of distance_embedding.
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in EsmModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # Probabilities may have been promoted (e.g. by the mask); cast back to the value dtype for the matmul.
        context_layer = torch.matmul(attention_probs.to(value_layer.dtype), value_layer)

        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        # Decoders keep a trailing slot where a key/value cache used to be returned; it is always None now.
        if self.is_decoder:
            outputs = outputs + (None,)
        return outputs
|
|
|
|
|
|
class EsmSelfOutput(nn.Module):
    """Projects the attention context back to `hidden_size`, applies dropout, and adds the residual input."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        return projected + input_tensor
|
|
|
|
|
|
class EsmFlashAttention2(EsmSelfAttention):
    """
    ESM flash attention module. This module inherits from `EsmSelfAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.

    When attention weights, a head mask or cross-attention are requested, it falls back to the eager
    `EsmSelfAttention.forward`.
    """

    def __init__(self, config, position_embedding_type=None, layer_idx=None):
        super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment,
        # that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
        # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a
        # wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
        self.dropout_prob = config.attention_probs_dropout_prob

    @deprecate_kwarg("past_key_value", version="4.54.0")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor]:
        """
        Flash-attention self-attention.

        Returns:
            `(attn_output,)` with `attn_output` of shape `(batch, seq, all_head_size)`, plus a trailing `None` when
            `is_decoder` — the same layout the eager path returns when attention weights are not requested.
            Requests the flash path cannot serve are delegated to the eager implementation.

        Raises:
            ValueError: for `relative_key` / `relative_key_query` position embeddings.
        """
        # Flash attention doesn't support output_attentions or cross attention
        if output_attentions or head_mask is not None or encoder_hidden_states is not None:
            logger.warning_once(
                "EsmFlashAttention2 does not support output_attentions, head_mask, or cross_attention. "
                "Falling back to the manual attention implementation. This warning can be removed using "
                'the argument `attn_implementation="eager"` when loading the model.'
            )
            # Pass everything by keyword: positionally, `output_attentions` would land in the parent's
            # `past_key_value` slot and the requested attention weights would be silently dropped.
            return super().forward(
                hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )

        bsz, q_len, _ = hidden_states.size()
        hidden_shape = (bsz, -1, self.num_attention_heads, self.attention_head_size)

        # Split heads exactly like the eager path -> (batch, heads, seq, head_dim). The base class has no
        # `transpose_for_scores` helper, so the view/transpose is done inline.
        query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
        key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
        value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32.
        input_dtype = query_layer.dtype
        device_type = query_layer.device.type if query_layer.device.type != "mps" else "cpu"
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = (
                    torch.get_autocast_dtype(device_type)
                    if hasattr(torch, "get_autocast_dtype")
                    else torch.get_autocast_gpu_dtype()
                )
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.query.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_layer = query_layer.to(target_dtype)
            key_layer = key_layer.to(target_dtype)
            value_layer = value_layer.to(target_dtype)

        # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
        # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
        # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original
        # ESM code and fix rotary embeddings.
        query_layer = query_layer * self.attention_head_size**-0.5

        if self.position_embedding_type == "rotary":
            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
        elif self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            raise ValueError(f"ESM flash attention does not support {self.position_embedding_type} embeddings")

        # Rotary embeddings operate on (batch, heads, seq, head_dim); flash attention expects
        # (batch, seq, heads, head_dim), so permute just before the call. The query is already scaled,
        # hence softmax_scale=1.0.
        attn_output = _flash_attention_forward(
            query_layer.permute(0, 2, 1, 3),
            key_layer.permute(0, 2, 1, 3),
            value_layer.permute(0, 2, 1, 3),
            attention_mask,
            query_length=q_len,
            is_causal=self.is_decoder,
            softmax_scale=1.0,
            dropout=self.dropout_prob if self.training else 0.0,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1)

        # Mirror the eager path when output_attentions is False: no attention-probs slot, and a trailing
        # None placeholder for decoders.
        outputs = (attn_output,)
        if self.is_decoder:
            outputs = outputs + (None,)

        return outputs
|
|
|
|
|
|
# Maps `config._attn_implementation` to the self-attention class instantiated by EsmAttention.
# Any other value (e.g. "sdpa") raises KeyError at construction time.
ESM_ATTENTION_CLASSES = {
    "eager": EsmSelfAttention,
    "flash_attention_2": EsmFlashAttention2,
}
|
|
|
|
|
|
class EsmAttention(nn.Module):
    """
    Pre-LayerNorm attention sub-block: LayerNorm -> (self or cross) attention -> output projection + residual.

    The attention implementation is picked from `ESM_ATTENTION_CLASSES` via `config._attn_implementation`.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        attention_cls = ESM_ATTENTION_CLASSES[config._attn_implementation]
        self.self = attention_cls(config, layer_idx=layer_idx)
        self.output = EsmSelfOutput(config)
        self.pruned_heads = set()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def prune_heads(self, heads):
        """Remove the given attention heads from the query/key/value and output projections."""
        if len(heads) == 0:
            return

        attn = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads, attn.num_attention_heads, attn.attention_head_size, self.pruned_heads
        )

        # Input-side projections lose output rows; the output projection loses input columns.
        for proj_name in ("query", "key", "value"):
            setattr(attn, proj_name, prune_linear_layer(getattr(attn, proj_name), index))
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the head bookkeeping in sync with the pruned weights.
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @deprecate_kwarg("past_key_value", version="4.54.0")
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        cache_position=None,
    ):
        """
        Returns `(attention_output, *extras)`, where `extras` is whatever the attention module returns after its
        context tensor (attention probabilities and/or a decoder placeholder). `past_key_value` and
        `cache_position` are accepted for compatibility and not used.
        """
        normed_states = self.LayerNorm(hidden_states)
        attn_results = self.self(
            normed_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
        )
        # The residual uses the un-normalized input.
        projected = self.output(attn_results[0], hidden_states)
        return (projected, *attn_results[1:])
|
|
|
|
|
|
class EsmIntermediate(nn.Module):
    """Feed-forward expansion: `hidden_size -> intermediate_size` followed by ESM's erf-based GELU."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return gelu(self.dense(hidden_states))
|
|
|
|
|
|
class EsmOutput(nn.Module):
    """Feed-forward contraction: `intermediate_size -> hidden_size`, dropout, then add the residual input."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        contracted = self.dense(hidden_states)
        contracted = self.dropout(contracted)
        return contracted + input_tensor
|
|
|
|
|
|
class EsmLayer(GradientCheckpointingLayer):
    """
    One ESM transformer layer: self-attention, optional cross-attention (decoder only), then a pre-LayerNorm
    feed-forward block with a residual connection.
    """

    def __init__(self, config):
        super().__init__()
        # NOTE(review): chunk_size_feed_forward and seq_len_dim are stored but not used by the forward below,
        # which calls feed_forward_chunk directly without chunking.
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = EsmAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = EsmAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        # Pre-LayerNorm for the feed-forward block (applied in feed_forward_chunk).
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    @deprecate_kwarg("past_key_value", version="4.54.0")
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        cache_position=None,
    ):
        """
        Returns:
            `(layer_output, *attention_extras)`: the extras are the self-attention outputs after the context tensor
            (and, when cross-attention runs, the cross-attention outputs minus their trailing decoder slot),
            followed by a `None` placeholder when `is_decoder`. `past_key_value` and `cache_position` are unused.
        """
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        # (now always a None placeholder), so strip it before collecting attention weights.
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise AttributeError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated"
                    " with cross-attention layers by setting `config.add_cross_attention=True`"
                )

            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

        layer_output = self.feed_forward_chunk(attention_output)

        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
        # (kept as a None placeholder so callers' tuple indexing stays stable).
        if self.is_decoder:
            outputs = outputs + (None,)
        return outputs

    def feed_forward_chunk(self, attention_output):
        """Pre-LayerNorm feed-forward: LayerNorm -> intermediate (GELU) -> output projection + residual."""
        attention_output_ln = self.LayerNorm(attention_output)
        intermediate_output = self.intermediate(attention_output_ln)
        # The residual uses the un-normalized attention output.
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
|
|
|
|
|
|
class EsmEncoder(nn.Module):
    """Runs the stack of `EsmLayer`s and applies the final `emb_layer_norm_after` LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(EsmLayer(config) for _ in range(config.num_hidden_layers))
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    @deprecate_kwarg("past_key_value", version="4.54.0")
    @deprecate_kwarg("use_cache", version="4.54.0")
    @can_return_tuple
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        cache_position=None,
    ):
        """
        Returns a `BaseModelOutputWithCrossAttentions`. `hidden_states` holds the input to every layer plus the final
        normalized output; `attentions` / `cross_attentions` are collected only when `output_attentions` is set
        (cross-attentions additionally require `config.add_cross_attention`). Cache arguments are ignored.
        """
        # Collectors are None when the corresponding output is not requested.
        hidden_history = [] if output_hidden_states else None
        self_attn_history = [] if output_attentions else None
        cross_attn_history = [] if output_attentions and self.config.add_cross_attention else None

        for idx, layer_module in enumerate(self.layer):
            if hidden_history is not None:
                hidden_history.append(hidden_states)

            layer_outputs = layer_module(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                head_mask=None if head_mask is None else head_mask[idx],
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]
            if self_attn_history is not None:
                self_attn_history.append(layer_outputs[1])
            if cross_attn_history is not None:
                cross_attn_history.append(layer_outputs[2])

        if self.emb_layer_norm_after is not None:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        if hidden_history is not None:
            hidden_history.append(hidden_states)

        def _as_tuple(collected):
            # Convert a collector to the tuple form the output class expects (None stays None).
            return None if collected is None else tuple(collected)

        return BaseModelOutputWithCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=_as_tuple(hidden_history),
            attentions=_as_tuple(self_attn_history),
            cross_attentions=_as_tuple(cross_attn_history),
        )
|
|
|
|
|
|
# EsmPooler: takes the hidden state at sequence position 0 (the first token), passes it through a
# hidden_size -> hidden_size Linear layer and a tanh, and returns a (batch, hidden_size) "pooled" vector.
# Documentation lives here, above the "Copied from" marker, so the copied body stays identical to BertPooler.
# Copied from transformers.models.bert.modeling_bert.BertPooler
class EsmPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
|
|
|
|
|
|
@auto_docstring
class EsmPreTrainedModel(PreTrainedModel):
    # Shared base for ESM models: config class, checkpoint prefix, weight initialization and capability flags.
    config: EsmConfig
    base_model_prefix = "esm"
    supports_gradient_checkpointing = True
    # Modules that must not be split across devices. "EsmFoldTriangularSelfAttentionBlock" is not defined in this
    # chunk — presumably it belongs to ESMFold models sharing this base; verify.
    _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"]
    # Checkpoint keys tolerated (silently ignored) when present but unexpected.
    _keys_to_ignore_on_load_unexpected = ["position_embeddings.weight"]
    # Enables selecting EsmFlashAttention2 via attn_implementation="flash_attention_2".
    _supports_flash_attn = True

    # _init_weights: Linear/Embedding weights ~ N(0, initializer_range) with zero biases and a zeroed padding row,
    # LayerNorm reset to weight=1 / bias=0, and the EsmLMHead output bias zeroed.
    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->EsmLMHead
    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, EsmLMHead):
            module.bias.data.zero_()

    def get_output_embeddings(self):
        # NOTE: get_output_embeddings() must return None to prevent accidental weight tying.
        # See e.g. https://github.com/huggingface/transformers/pull/39339#discussion_r2219126400
        return None
|
|
|
|
|
|
@auto_docstring
class EsmModel(EsmPreTrainedModel):
    """

    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    """

    def __init__(self, config, add_pooling_layer=True):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer. When `False`, `pooler_output` in the forward output is `None`.
        """
        super().__init__(config)
        # NOTE(review): likely redundant with `super().__init__(config)`; kept unchanged.
        self.config = config

        self.embeddings = EsmEmbeddings(config)
        self.encoder = EsmEncoder(config)

        self.pooler = EsmPooler(config) if add_pooling_layer else None

        # One input feature per (layer, head) attention map: `predict_contacts` stacks the attentions of all layers.
        self.contact_head = EsmContactPredictionHead(
            in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module (`self.embeddings.word_embeddings`)."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the token embedding module with `value`."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base
        class PreTrainedModel.
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @deprecate_kwarg("past_key_values", version="4.54.0")
    @deprecate_kwarg("use_cache", version="4.54.0")
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # NOTE(review): `return_dict` is resolved but not read below; this method always builds a
        # `BaseModelOutputWithPoolingAndCrossAttentions` (tuple conversion is presumably done by `@can_return_tuple`).
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # `past_key_values` and `use_cache` are deprecated (see decorators) and are not used by this method.

        # Exactly one of `input_ids` / `inputs_embeds` must be given; it determines (batch_size, seq_length).
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # No mask supplied: every position is attended to.
        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length)), device=device)

        # The flash-attention path receives the raw 2-D mask; other implementations get the extended mask below.
        if self.config._attn_implementation == "flash_attention_2":
            extended_attention_mask = attention_mask

        else:
            # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
            # ourselves in which case we just need to make it broadcastable to all heads.
            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        # (Only prepared for decoder configs that actually receive encoder states; a missing mask means all-ones.)
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # The embedding layer also receives the 2-D attention_mask (not the extended one).
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        # Always request the dataclass form so attribute access below works.
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )

    def predict_contacts(self, tokens, attention_mask):
        """
        Predict contacts from the model's attention maps.

        Runs a forward pass with `output_attentions=True` and stacks the per-layer attentions along a new dim 1. It
        then zeroes the attention rows and columns of masked-out positions and feeds the result, together with
        `tokens`, to `self.contact_head`.

        Args:
            tokens: token ids, passed positionally as `input_ids`.
            attention_mask: padding mask whose zero entries zero out the matching attentions. The unsqueeze pattern
                below assumes it is 2-D, presumably `(batch_size, sequence_length)`.

        Returns:
            The output of `self.contact_head(tokens, attns)`.

        NOTE(review): requires an attention implementation that actually returns attention weights; confirm the
        behavior under flash attention.
        """
        attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions
        attns = torch.stack(attns, dim=1)  # Matches the original model layout
        # In the original model, attentions for padding tokens are completely zeroed out.
        # This makes no difference most of the time because the other tokens won't attend to them,
        # but it does for the contact prediction task, which takes attentions as input,
        # so we have to mimic that here.
        # With a 2-D mask: the first product broadcasts the mask over the last attention dim (columns),
        # the second over the second-to-last dim (rows).
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
        return self.contact_head(tokens, attns)
|
|
|
|
|
|
@auto_docstring
class EsmForMaskedLM(EsmPreTrainedModel):
    # Listed as tied; unlike the base class, `get_output_embeddings` below exposes this decoder.
    _tied_weights_keys = ["lm_head.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            # A decoder config is tolerated but warned about: masked LM expects bidirectional attention.
            logger.warning(
                "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        # Backbone without a pooler: the LM head consumes per-token states only.
        self.esm = EsmModel(config, add_pooling_layer=False)
        self.lm_head = EsmLMHead(config)

        self.init_weights()
        self.post_init()

    def get_output_embeddings(self):
        """Return the vocabulary projection of the LM head."""
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        """Install `new_embeddings` as the LM head's vocabulary projection."""
        self.lm_head.decoder = new_embeddings

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Target token ids for the masked language modeling loss. Values should lie in `[0, ...,
            config.vocab_size - 1]`; positions labelled `-100` are ignored, so the loss only covers the remaining
            tokens.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        vocab_logits = self.lm_head(backbone_outputs.last_hidden_state)

        masked_lm_loss = None
        if labels is not None:
            # Bring the targets onto the logits' device before flattening both for the loss.
            flat_targets = labels.to(vocab_logits.device).view(-1)
            flat_logits = vocab_logits.view(-1, self.config.vocab_size)
            masked_lm_loss = CrossEntropyLoss()(flat_logits, flat_targets)

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=vocab_logits,
            hidden_states=backbone_outputs.hidden_states,
            attentions=backbone_outputs.attentions,
        )

    def predict_contacts(self, tokens, attention_mask):
        """Delegate contact prediction to the `EsmModel` backbone."""
        return self.esm.predict_contacts(tokens, attention_mask=attention_mask)
|
|
|
|
|
|
class EsmLMHead(nn.Module):
    """Masked-language-modeling head: dense -> gelu -> layer norm -> vocabulary projection plus a separate bias."""

    def __init__(self, config):
        super().__init__()
        # Submodules/parameters are created in this order; keep it so initialization order stays the same.
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # The projection itself has no bias; the vocabulary bias is a standalone parameter instead.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, features, **kwargs):
        normalized = self.layer_norm(gelu(self.dense(features)))
        # Map back to vocabulary size and add the standalone bias.
        return self.decoder(normalized) + self.bias
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    ESM Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """
)
class EsmForSequenceClassification(EsmPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        # The classification head reads the first token itself, so the backbone's pooler is not built.
        self.esm = EsmModel(config, add_pooling_layer=False)
        self.classifier = EsmClassificationHead(config)

        self.init_weights()
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Targets for the sequence-level loss. If `config.problem_type` is unset it is inferred on first use:
            `config.num_labels == 1` selects a regression (mean-squared error) loss; `config.num_labels > 1` with
            `torch.long`/`torch.int` labels in `[0, ..., config.num_labels - 1]` selects cross-entropy; anything else
            selects a multi-label binary cross-entropy loss.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        logits = self.classifier(backbone_outputs.last_hidden_state)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)

            if self.config.problem_type is None:
                # Infer the task from the head size and label dtype, then remember it on the config.
                if self.num_labels == 1:
                    inferred_problem = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    inferred_problem = "single_label_classification"
                else:
                    inferred_problem = "multi_label_classification"
                self.config.problem_type = inferred_problem

            problem_type = self.config.problem_type
            if problem_type == "regression":
                mse = MSELoss()
                loss = mse(logits.squeeze(), labels.squeeze()) if self.num_labels == 1 else mse(logits, labels)
            elif problem_type == "single_label_classification":
                loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                loss = BCEWithLogitsLoss()(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_outputs.hidden_states,
            attentions=backbone_outputs.attentions,
        )
|
|
|
|
|
|
@auto_docstring
class EsmForTokenClassification(EsmPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Per-token classifier on top of the (pooler-less) backbone, with dropout in front of it.
        self.esm = EsmModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Per-token targets for the cross-entropy loss, in `[0, ..., config.num_labels - 1]`; positions labelled
            `-100` are ignored by the loss.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        token_states = self.dropout(backbone_outputs.last_hidden_state)
        logits = self.classifier(token_states)

        loss = None
        if labels is not None:
            flat_targets = labels.to(logits.device).view(-1)
            loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), flat_targets)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_outputs.hidden_states,
            attentions=backbone_outputs.attentions,
        )
|
|
|
|
|
|
class EsmClassificationHead(nn.Module):
    """Sequence-level classifier applied to the first position of each sequence (dropout -> dense -> tanh -> dropout -> linear)."""

    def __init__(self, config):
        super().__init__()
        # Creation order is kept (dense, dropout, out_proj) so parameter initialization order is unchanged.
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        first_position = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        # The same dropout module is applied before the projection and again before the output layer.
        activated = torch.tanh(self.dense(self.dropout(first_position)))
        return self.out_proj(self.dropout(activated))
|
|
|
|
|
|
def create_position_ids_from_input_ids(input_ids, padding_idx):
    """
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Non-padding tokens are numbered consecutively along dim 1, and padding tokens do not advance the count. The first
    non-padding token of a row gets `padding_idx + 1`, the next `padding_idx + 2`, and so on. Every padding token gets
    `padding_idx` itself.

    Args:
        input_ids (`torch.Tensor`):
            Integer token ids, counted along dim 1, so at least 2-D (e.g. `(batch_size, sequence_length)`).
        padding_idx (`int`):
            Id of the padding token.

    Returns:
        `torch.LongTensor` with the same shape as `input_ids`, holding the position ids.

    Example: with `padding_idx=1`, `[[5, 6, 1, 1], [1, 7, 1, 8]]` maps to `[[2, 3, 1, 1], [1, 2, 1, 3]]`.
    """
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    mask = input_ids.ne(padding_idx).int()
    # Running count of non-padding tokens, zeroed again at padding positions.
    incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
    # Shift so real tokens start at padding_idx + 1 and padding maps to padding_idx.
    return incremental_indices.long() + padding_idx
|
|
|
|
|
|
# Public names of this module: the set exported by `from <module> import *`.
__all__ = [
    "EsmForMaskedLM",
    "EsmForSequenceClassification",
    "EsmForTokenClassification",
    "EsmModel",
    "EsmPreTrainedModel",
]
|