1359 lines
57 KiB
Python
1359 lines
57 KiB
Python
# coding=utf-8
|
|
# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""PyTorch ALBERT model."""
|
|
|
|
import math
|
|
import os
|
|
from dataclasses import dataclass
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
|
|
from ...activations import ACT2FN
|
|
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
|
|
from ...modeling_outputs import (
|
|
BaseModelOutput,
|
|
BaseModelOutputWithPooling,
|
|
MaskedLMOutput,
|
|
MultipleChoiceModelOutput,
|
|
QuestionAnsweringModelOutput,
|
|
SequenceClassifierOutput,
|
|
TokenClassifierOutput,
|
|
)
|
|
from ...modeling_utils import PreTrainedModel
|
|
from ...pytorch_utils import (
|
|
apply_chunking_to_forward,
|
|
find_pruneable_heads_and_indices,
|
|
is_torch_greater_or_equal_than_2_2,
|
|
prune_linear_layer,
|
|
)
|
|
from ...utils import ModelOutput, auto_docstring, logging
|
|
from .configuration_albert import AlbertConfig
|
|
|
|
|
|
# Module-level logger, obtained through the library's own `logging` utilities (imported from `...utils`).
logger = logging.get_logger(name=__name__)
|
|
|
|
|
|
def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
    """Load the variables of a TensorFlow ALBERT checkpoint into a PyTorch model.

    Every TF variable name is rewritten to the PyTorch attribute path used in this file (see the
    renaming rules below), split on ``/`` and resolved attribute by attribute starting from
    ``model``. The resolved tensor's ``.data`` is then replaced with the checkpoint array.

    Args:
        model: PyTorch model whose parameters are overwritten in place.
        config: Model configuration. Not read by this function; kept for interface compatibility.
        tf_checkpoint_path: Path to the TensorFlow checkpoint.

    Returns:
        The same ``model`` instance.

    Raises:
        ImportError: If NumPy or TensorFlow cannot be imported.
        ValueError: If a checkpoint array's shape differs from the resolved target's shape. The
            exception's ``args`` are ``(message, target_shape, array_shape)``.
    """
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model. Each variable name is already reported through the logger here,
    # so nothing is printed to stdout.
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        original_name = name

        # If saved from the TF HUB module
        name = name.replace("module/", "")

        # Renaming and simplifying
        name = name.replace("ffn_1", "ffn")
        name = name.replace("bert/", "albert/")
        name = name.replace("attention_1", "attention")
        name = name.replace("transform/", "")
        name = name.replace("LayerNorm_1", "full_layer_layer_norm")
        name = name.replace("LayerNorm", "attention/LayerNorm")
        name = name.replace("transformer/", "")

        # The feed forward layer had an 'intermediate' step which has been abstracted away
        name = name.replace("intermediate/dense/", "")
        name = name.replace("ffn/intermediate/output/dense/", "ffn_output/")

        # ALBERT attention was split between self and output which have been abstracted away
        name = name.replace("/output/", "/")
        name = name.replace("/self/", "/")

        # The pooler is a linear layer
        name = name.replace("pooler/dense", "pooler")

        # The classifier was simplified to predictions from cls/predictions
        name = name.replace("cls/predictions", "predictions")
        name = name.replace("predictions/attention", "predictions")

        # Naming was changed to be more explicit
        name = name.replace("embeddings/attention", "embeddings")
        name = name.replace("inner_group_", "albert_layers/")
        name = name.replace("group_", "albert_layer_groups/")

        # Classifier
        if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name):
            name = "classifier/" + name

        # No ALBERT model currently handles the next sentence prediction task
        if "seq_relationship" in name:
            name = name.replace("seq_relationship/output_", "sop_classifier/classifier/")
            name = name.replace("weights", "weight")

        name = name.split("/")

        # Ignore the gradients applied by the LAMB/ADAM optimizers.
        if (
            "adam_m" in name
            or "adam_v" in name
            or "AdamWeightDecayOptimizer" in name
            or "AdamWeightDecayOptimizer_1" in name
            or "global_step" in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            continue

        pointer = model
        for m_name in name:
            # Components such as "layer_3" are split into ("layer", "3") so the index can be applied.
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]

            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    # Only this path component is skipped: resolution carries on from the current
                    # `pointer` with the next component, so components absent from `model` are
                    # tolerated rather than aborting this variable.
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]

        # `m_name` is the last path component here.
        if m_name.endswith("_embeddings"):
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            # TF dense kernels are stored transposed relative to nn.Linear.weight.
            array = np.transpose(array)
        if pointer.shape != array.shape:
            raise ValueError(
                f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched",
                pointer.shape,
                array.shape,
            )
        logger.info(f"Initialize PyTorch weight {name} from {original_name}")
        pointer.data = torch.from_numpy(array)

    return model
|
|
|
|
|
|
class AlbertEmbeddings(nn.Module):
    """
    Construct the embeddings from word, position and token_type embeddings.

    The word (or caller-supplied `inputs_embeds`) and token-type embeddings are summed. Learned
    position embeddings are added only when `position_embedding_type` is "absolute", which is the
    default when the config lacks the attribute. The sum then goes through LayerNorm and dropout.
    All tables have width `config.embedding_size`; AlbertTransformer later projects that to
    `config.hidden_size`.

    In `forward`, omitted `position_ids` are sliced from the `position_ids` buffer and offset by
    `past_key_values_length`. Omitted `token_type_ids` are sliced from the all-zero `token_type_ids`
    buffer.
    """

    def __init__(self, config: AlbertConfig):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids buffer of shape (1, max_position_embeddings) holding 0 .. max_position_embeddings - 1.
        # It is registered with persistent=False, so it is NOT included in the saved state dict.
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        # "absolute" adds learned position embeddings in forward(); any other value skips them here
        # (AlbertAttention implements the "relative_key" / "relative_key_query" variants).
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        # All-zero default token type ids with the same shape as position_ids; also non-persistent.
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
        # issue #5664
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
|
|
|
|
|
|
class AlbertAttention(nn.Module):
    """Eager multi-head self-attention sub-layer of an ALBERT layer.

    Projects `hidden_states` to queries, keys and values and computes scaled dot-product attention,
    optionally adding relative-position scores. The context is projected back with `dense` and passed
    through dropout. The result is `LayerNorm(hidden_states + projected_context)`.
    """

    def __init__(self, config: AlbertConfig):
        super().__init__()
        # The per-head contexts are concatenated back to `all_head_size` features and fed to `dense`,
        # which expects `hidden_size` inputs. `hidden_size` must therefore split evenly across the heads;
        # otherwise the model would only fail later, inside `dense`, at the first forward pass.
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pruned_heads = set()

        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            # One learned vector per signed distance in [-(max_position_embeddings - 1), max_position_embeddings - 1].
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

    def prune_heads(self, heads: list[int]) -> None:
        """Remove the given attention heads from the q/k/v projections and the output projection.

        Args:
            heads: Indices of heads to remove. Heads already pruned are resolved through
                `find_pruneable_heads_and_indices`. An empty list is a no-op.
        """
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.num_attention_heads, self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers: output rows of q/k/v and input columns of the output projection.
        self.query = prune_linear_layer(self.query, index)
        self.key = prune_linear_layer(self.key, index)
        self.value = prune_linear_layer(self.value, index)
        self.dense = prune_linear_layer(self.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.num_attention_heads = self.num_attention_heads - len(heads)
        self.all_head_size = self.attention_head_size * self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: bool = False,
    ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
        """Apply self-attention to `hidden_states`.

        Args:
            hidden_states: Rank-3 input `(batch_size, seq_length, hidden)`; also used as the residual.
            attention_mask: Optional mask added to the raw scores before softmax (presumably the additive
                mask built in `AlbertModel.forward`; confirm for other callers).
            head_mask: Optional mask multiplied into the attention probabilities after dropout.
            output_attentions: Whether to also return the attention probabilities.

        Returns:
            `(layernormed_context,)`, or `(layernormed_context, attention_probs)` if `output_attentions`.
        """
        batch_size, seq_length, _ = hidden_states.shape

        # (batch, seq, all_head_size) -> (batch, heads, seq, head_size)
        query_layer = (
            self.query(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )
        key_layer = (
            self.key(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )
        value_layer = (
            self.value(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in AlbertModel.forward()).
            attention_scores = attention_scores + attention_mask

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            # Shift signed distances into the non-negative index range of `distance_embedding`.
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
        # (batch, heads, seq, head_size) -> (batch, seq, all_head_size)
        context_layer = context_layer.transpose(2, 1).flatten(2)

        projected_context_layer = self.dense(context_layer)
        projected_context_layer_dropout = self.output_dropout(projected_context_layer)
        layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
        return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)
|
|
|
|
|
|
class AlbertSdpaAttention(AlbertAttention):
    """ALBERT self-attention whose score/softmax/context step uses
    `torch.nn.functional.scaled_dot_product_attention`.

    The projections, output projection, dropout, residual and LayerNorm are inherited from
    `AlbertAttention`. Cases this SDPA path cannot express are delegated to the eager parent:
    non-absolute position embeddings, `output_attentions=True`, and a per-head `head_mask`.
    """

    def __init__(self, config):
        super().__init__(config)
        # Attention-probability dropout, applied inside SDPA and only while training.
        self.dropout_prob = config.attention_probs_dropout_prob
        # Work around non-contiguous q/k/v issues on torch < 2.2 (see forward).
        self.require_contiguous_qkv = not is_torch_greater_or_equal_than_2_2

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: bool = False,
    ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
        """Apply self-attention to `hidden_states`, via SDPA when possible.

        Args:
            hidden_states: Rank-3 input `(batch_size, seq_len, hidden)`; also used as the residual.
            attention_mask: Optional mask passed to SDPA as `attn_mask`, or added to the scores in the
                eager fallback.
            head_mask: Optional per-head mask. When given, the eager implementation is used so that
                the mask is actually applied.
            output_attentions: Whether to return attention probabilities (eager fallback only).

        Returns:
            `(layernormed_context,)`, or `(layernormed_context, attention_probs)` when falling back
            with `output_attentions=True`.
        """
        # SDPA has no hook for relative-position scores, for returning the probabilities, or for
        # scaling probabilities per head, so those cases run the eager implementation. `head_mask`
        # must be forwarded; otherwise it would be silently ignored.
        if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
            logger.warning(
                "AlbertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
                "non-absolute `position_embedding_type`, `output_attentions=True` or `head_mask`. Falling back to "
                "the eager attention implementation, but specifying the eager implementation will be required from "
                "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
                '`attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states, attention_mask, head_mask=head_mask, output_attentions=output_attentions
            )

        batch_size, seq_len, _ = hidden_states.size()
        # (batch, seq, all_head_size) -> (batch, heads, seq, head_size)
        query_layer = (
            self.query(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )
        key_layer = (
            self.key(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )
        value_layer = (
            self.value(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )

        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
        # Reference: https://github.com/pytorch/pytorch/issues/112577
        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
            query_layer = query_layer.contiguous()
            key_layer = key_layer.contiguous()
            value_layer = value_layer.contiguous()

        attention_output = torch.nn.functional.scaled_dot_product_attention(
            query=query_layer,
            key=key_layer,
            value=value_layer,
            attn_mask=attention_mask,
            dropout_p=self.dropout_prob if self.training else 0.0,
            is_causal=False,
        )

        # (batch, heads, seq, head_size) -> (batch, seq, all_head_size)
        attention_output = attention_output.transpose(1, 2)
        attention_output = attention_output.reshape(batch_size, seq_len, self.all_head_size)

        projected_context_layer = self.dense(attention_output)
        projected_context_layer_dropout = self.output_dropout(projected_context_layer)
        layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
        return (layernormed_context_layer,)
|
|
|
|
|
|
# Maps `config._attn_implementation` to the attention class that AlbertLayer instantiates.
ALBERT_ATTENTION_CLASSES = dict(
    eager=AlbertAttention,
    sdpa=AlbertSdpaAttention,
)
|
|
|
|
|
|
class AlbertLayer(nn.Module):
    """One ALBERT transformer layer: self-attention followed by a position-wise feed-forward block.

    The feed-forward result is added to the attention output and normalized with
    `full_layer_layer_norm`. The feed-forward block is run through `apply_chunking_to_forward`, which
    may evaluate it in chunks of `config.chunk_size_feed_forward` along the sequence dimension.
    """

    def __init__(self, config: AlbertConfig):
        super().__init__()

        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Dimension of the sequence axis, used when chunking the feed-forward block.
        self.seq_len_dim = 1
        # Submodules are registered in a fixed order; it determines parameter ordering.
        self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        attention_cls = ALBERT_ATTENTION_CLASSES[config._attn_implementation]
        self.attention = attention_cls(config)
        self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
        self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
        self.activation = ACT2FN[config.hidden_act]
        # NOTE(review): `dropout` is constructed but never applied in `forward`; confirm whether the
        # feed-forward output was meant to be dropped out before the residual LayerNorm.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention, then the feed-forward block with a residual LayerNorm.

        `output_hidden_states` is accepted for interface compatibility and is not used here.

        Returns:
            `(layer_output,)` followed by whatever extra items the attention module returned
            (the attention probabilities when `output_attentions` is set).
        """
        attn_outputs = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
        attn_hidden = attn_outputs[0]
        extras = attn_outputs[1:]

        ffn_hidden = apply_chunking_to_forward(
            self.ff_chunk,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            attn_hidden,
        )
        layer_output = self.full_layer_layer_norm(ffn_hidden + attn_hidden)

        return (layer_output, *extras)

    def ff_chunk(self, attention_output: torch.Tensor) -> torch.Tensor:
        """Apply `ffn` -> activation -> `ffn_output` to (a chunk of) the attention output."""
        expanded = self.activation(self.ffn(attention_output))
        return self.ffn_output(expanded)
|
|
|
|
|
|
class AlbertLayerGroup(nn.Module):
    """A group of `config.inner_group_num` ALBERT layers applied in sequence.

    `AlbertTransformer` selects a group for each of its `num_hidden_layers` steps, so one group may be
    invoked several times with the same parameters.
    """

    def __init__(self, config: AlbertConfig):
        super().__init__()

        self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
        """Run `hidden_states` through every layer of the group.

        Args:
            hidden_states: Input to the first layer of the group.
            attention_mask: Mask forwarded unchanged to every layer.
            head_mask: Indexable with one entry per layer of the group (entries may be `None`), or
                `None` to disable head masking for every layer.
            output_attentions: Whether to collect each layer's attention probabilities.
            output_hidden_states: Whether to collect each layer's output hidden states.

        Returns:
            `(last_hidden_state,)`, then the tuple of per-layer hidden states if
            `output_hidden_states`, then the tuple of per-layer attentions if `output_attentions`.
        """
        layer_hidden_states = ()
        layer_attentions = ()

        for layer_index, albert_layer in enumerate(self.albert_layers):
            # Honour the annotated default (`head_mask=None`) instead of failing on `None[layer_index]`.
            layer_head_mask = None if head_mask is None else head_mask[layer_index]
            layer_output = albert_layer(hidden_states, attention_mask, layer_head_mask, output_attentions)
            hidden_states = layer_output[0]

            if output_attentions:
                layer_attentions = layer_attentions + (layer_output[1],)

            if output_hidden_states:
                layer_hidden_states = layer_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if output_hidden_states:
            outputs = outputs + (layer_hidden_states,)
        if output_attentions:
            outputs = outputs + (layer_attentions,)
        return outputs  # last-layer hidden state, (layer hidden states), (layer attentions)
|
|
|
|
|
|
class AlbertTransformer(nn.Module):
    """ALBERT encoder: maps embeddings to `hidden_size` and applies the shared layer groups.

    `num_hidden_layers` steps are executed over only `num_hidden_groups` parameter groups. Layer `i`
    uses group `floor(i * num_hidden_groups / num_hidden_layers)`.
    """

    def __init__(self, config: AlbertConfig):
        super().__init__()

        self.config = config
        self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
        self.albert_layer_groups = nn.ModuleList([AlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[BaseModelOutput, tuple]:
        """Encode `hidden_states` (width `embedding_size`) through all hidden layers.

        Args:
            hidden_states: Embedding output; projected to `hidden_size` first.
            attention_mask: Mask forwarded to every layer group.
            head_mask: Indexable with one entry per hidden layer, or `None`.
            output_attentions: Whether to collect attention probabilities from every layer.
            output_hidden_states: Whether to collect the projected input plus each group's output.
            return_dict: Return a `BaseModelOutput` instead of a tuple.

        Returns:
            `BaseModelOutput`, or a tuple of its non-`None` fields in order
            `(last_hidden_state, hidden_states, attentions)`.
        """
        hidden_states = self.embedding_hidden_mapping_in(hidden_states)

        all_hidden_states = (hidden_states,) if output_hidden_states else None
        all_attentions = () if output_attentions else None

        num_layers = self.config.num_hidden_layers
        num_groups = self.config.num_hidden_groups
        head_mask = [None] * num_layers if head_mask is None else head_mask

        # Number of layers in a hidden group (loop-invariant).
        layers_per_group = num_layers // num_groups

        for i in range(num_layers):
            # Index of the hidden group serving layer i. Integer arithmetic avoids float rounding.
            group_idx = i * num_groups // num_layers
            group_head_mask = head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group]

            layer_group_output = self.albert_layer_groups[group_idx](
                hidden_states,
                attention_mask,
                group_head_mask,
                output_attentions,
                output_hidden_states,
            )
            hidden_states = layer_group_output[0]

            if output_attentions:
                # The group returns its per-layer attentions as the last tuple element.
                all_attentions = all_attentions + layer_group_output[-1]

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )
|
|
|
|
|
|
@auto_docstring
class AlbertPreTrainedModel(PreTrainedModel):
    # Shared base for the ALBERT models in this file. It declares:
    # - the config class and the TF-checkpoint loader;
    # - the attribute name of the base model inside task heads;
    # - SDPA support and the weight-initialization rule.
    config: AlbertConfig
    load_tf_weights = load_tf_weights_in_albert
    base_model_prefix = "albert"
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights of a single submodule in place.

        - Linear and Embedding weights are drawn from N(0, config.initializer_range).
        - Linear biases and the Embedding padding row are zeroed.
        - LayerNorm is reset to bias 0, weight 1.
        - The bias of `AlbertMLMHead` is zeroed.
        - Any other module is left untouched.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    module.bias.data.zero_()
            elif module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, AlbertMLMHead):
            module.bias.data.zero_()
|
|
|
|
|
|
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`AlbertForPreTraining`].
    """
)
class AlbertForPreTrainingOutput(ModelOutput):
    r"""
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the masked language modeling loss and the sentence order prediction
        (classification) loss.
    prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    sop_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
        Prediction scores of the sentence order prediction (classification) head (scores of original vs.
        swapped sentence order before SoftMax).
    """

    # Field order is significant: ModelOutput also exposes these values positionally.
    loss: Optional[torch.FloatTensor] = None
    prediction_logits: Optional[torch.FloatTensor] = None
    sop_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
|
|
|
|
|
|
@auto_docstring
class AlbertModel(AlbertPreTrainedModel):
    # Bare ALBERT encoder: embeddings -> shared-parameter transformer -> optional tanh pooler over
    # the first token.
    config: AlbertConfig
    base_model_prefix = "albert"

    def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)

        self.config = config
        self.embeddings = AlbertEmbeddings(config)
        self.encoder = AlbertTransformer(config)
        if add_pooling_layer:
            self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
            self.pooler_activation = nn.Tanh()
        else:
            self.pooler = None
            self.pooler_activation = None

        self.attn_implementation = config._attn_implementation
        # Read with the same default as AlbertEmbeddings / AlbertAttention so all three agree even
        # when the config lacks the attribute.
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}.

        ALBERT has a different architecture in that its layers are shared across groups, which then have
        inner groups. If an ALBERT model has 12 hidden layers and 2 hidden groups, with two inner groups,
        there is a total of 4 different layers.

        These layers are flattened: the indices [0,1] correspond to the two inner layers of the first hidden
        group, while [2,3] correspond to the two inner layers of the second hidden group.

        Any layer with an index other than [0,1,2,3] will result in an error. See base class PreTrainedModel
        for more information about head pruning.
        """
        for layer, heads in heads_to_prune.items():
            # Flattened index -> (hidden group, layer within that group).
            group_idx, inner_group_idx = divmod(layer, self.config.inner_group_num)
            self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[BaseModelOutputWithPooling, tuple]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of input_ids / inputs_embeds must be given.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        # Default token types come from the all-zero buffer owned by the embeddings module.
        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        embedding_output = self.embeddings(
            input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )

        # The SDPA-specific 4D mask is used only when every layer will take the SDPA path: absolute
        # positions, no head mask, no returned attentions. Otherwise build the eager additive mask of
        # shape (batch, 1, 1, seq): 0 where attended, the dtype's minimum where masked out.
        use_sdpa_attention_mask = (
            self.attn_implementation == "sdpa"
            and self.position_embedding_type == "absolute"
            and head_mask is None
            and not output_attentions
        )

        if use_sdpa_attention_mask:
            extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                attention_mask, embedding_output.dtype, tgt_len=seq_length
            )
        else:
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]

        # Pool by passing the first token's hidden state through a dense layer and tanh.
        pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0])) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
    `sentence order prediction (classification)` head.
    """
)
class AlbertForPreTraining(AlbertPreTrainedModel):
    _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]

    def __init__(self, config: AlbertConfig):
        super().__init__(config)

        self.albert = AlbertModel(config)
        self.predictions = AlbertMLMHead(config)
        self.sop_classifier = AlbertSOPHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self) -> nn.Linear:
        return self.predictions.decoder

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        # Rebind the head's `bias` to the new decoder's bias (as AlbertForMaskedLM does); otherwise
        # `predictions.bias` would keep referencing the previous decoder's bias.
        self.predictions.decoder = new_embeddings
        self.predictions.bias = new_embeddings.bias

    def get_input_embeddings(self) -> nn.Embedding:
        return self.albert.embeddings.word_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        sentence_order_label: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[AlbertForPreTrainingOutput, tuple]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        sentence_order_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sentence order prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring) Indices should be in `[0, 1]`. `0` indicates original order (sequence A, then
            sequence B), `1` indicates switched order (sequence B, then sequence A). The (combined) loss is only
            returned when both `labels` and `sentence_order_label` are provided.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, AlbertForPreTraining
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
        >>> model = AlbertForPreTraining.from_pretrained("albert/albert-base-v2")

        >>> # Encode one example and add a batch dimension (batch size 1)
        >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)
        >>> outputs = model(input_ids)

        >>> prediction_logits = outputs.prediction_logits
        >>> sop_logits = outputs.sop_logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # First two outputs of the base model: per-token hidden states and the pooled output.
        sequence_output, pooled_output = outputs[:2]

        prediction_scores = self.predictions(sequence_output)
        sop_scores = self.sop_classifier(pooled_output)

        # The loss is the unweighted sum of the MLM and SOP losses, and only when both label sets are given.
        total_loss = None
        if labels is not None and sentence_order_label is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            sentence_order_loss = loss_fct(sop_scores.view(-1, 2), sentence_order_label.view(-1))
            total_loss = masked_lm_loss + sentence_order_loss

        if not return_dict:
            output = (prediction_scores, sop_scores) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return AlbertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            sop_logits=sop_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
|
class AlbertMLMHead(nn.Module):
    """Masked-language-modeling head.

    Projects hidden states from `hidden_size` down to `embedding_size`, applies the configured activation and a
    LayerNorm, then decodes onto the vocabulary. The decoder's bias is the same parameter object as `self.bias`.
    """

    def __init__(self, config: AlbertConfig):
        super().__init__()

        # Creation order is intentionally unchanged: it fixes parameter registration order and RNG consumption.
        self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.dense = nn.Linear(config.hidden_size, config.embedding_size)
        self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
        self.activation = ACT2FN[config.hidden_act]
        # Make the decoder use the head's own bias parameter.
        self.decoder.bias = self.bias

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return vocabulary logits for `hidden_states`."""
        projected = self.activation(self.dense(hidden_states))
        normalized = self.LayerNorm(projected)
        return self.decoder(normalized)

    def _tie_weights(self) -> None:
        """Re-link `self.bias` and `self.decoder.bias` so they refer to the same parameter."""
        decoder_bias = self.decoder.bias
        if decoder_bias.device.type != "meta":
            # The two may have become disconnected (on TPU or when the bias is resized): adopt the decoder's bias.
            self.bias = decoder_bias
        else:
            # For accelerate compatibility and to not break backward compatibility.
            self.decoder.bias = self.bias
|
|
|
|
|
|
class AlbertSOPHead(nn.Module):
    """Sentence-order-prediction head: dropout followed by a linear classifier over the pooled output."""

    def __init__(self, config: AlbertConfig):
        super().__init__()

        # Fall back to `hidden_dropout_prob` when `classifier_dropout_prob` is unset, the same way
        # AlbertForTokenClassification does; `nn.Dropout(None)` would raise otherwise.
        classifier_dropout_prob = (
            config.classifier_dropout_prob
            if config.classifier_dropout_prob is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
        """Return `config.num_labels` logits per pooled vector (last dim goes hidden_size -> num_labels)."""
        dropout_pooled_output = self.dropout(pooled_output)
        logits = self.classifier(dropout_pooled_output)
        return logits
|
|
|
|
|
|
@auto_docstring
class AlbertForMaskedLM(AlbertPreTrainedModel):
    _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)

        # No pooler: only per-token hidden states are needed for MLM.
        self.albert = AlbertModel(config, add_pooling_layer=False)
        self.predictions = AlbertMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self) -> nn.Linear:
        return self.predictions.decoder

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        # Swap the decoder and point the head's bias at the new decoder's bias.
        head = self.predictions
        head.decoder = new_embeddings
        head.bias = new_embeddings.bias

    def get_input_embeddings(self) -> nn.Embedding:
        return self.albert.embeddings.word_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[MaskedLMOutput, tuple]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for the masked language modeling loss. Each index should lie in `[-100, 0, ...,
            config.vocab_size]` (see the `input_ids` docstring). Positions labelled `-100` are ignored (masked); the
            loss is computed only over tokens whose labels lie in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, AlbertForMaskedLM

        >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
        >>> model = AlbertForMaskedLM.from_pretrained("albert/albert-base-v2")

        >>> # add mask_token
        >>> inputs = tokenizer("The capital of [MASK] is Paris.", return_tensors="pt")
        >>> with torch.no_grad():
        ...     logits = model(**inputs).logits

        >>> # retrieve index of [MASK]
        >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
        >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
        >>> tokenizer.decode(predicted_token_id)
        'france'
        ```

        ```python
        >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
        >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
        >>> outputs = model(**inputs, labels=labels)
        >>> round(outputs.loss.item(), 2)
        0.81
        ```
        """
        use_return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        base_outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=use_return_dict,
        )
        logits = self.predictions(base_outputs[0])

        mlm_loss = None
        if labels is not None:
            flat_logits = logits.view(-1, self.config.vocab_size)
            mlm_loss = CrossEntropyLoss()(flat_logits, labels.view(-1))

        if use_return_dict:
            return MaskedLMOutput(
                loss=mlm_loss,
                logits=logits,
                hidden_states=base_outputs.hidden_states,
                attentions=base_outputs.attentions,
            )

        # Tuple mode: skip the (absent) pooled output at index 1 of the base outputs.
        tuple_output = (logits,) + base_outputs[2:]
        return tuple_output if mlm_loss is None else (mlm_loss,) + tuple_output
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """
)
class AlbertForSequenceClassification(AlbertPreTrainedModel):
    def __init__(self, config: AlbertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.albert = AlbertModel(config)
        # Fall back to `hidden_dropout_prob` when `classifier_dropout_prob` is unset, as AlbertForTokenClassification
        # does; `nn.Dropout(None)` would raise otherwise.
        classifier_dropout_prob = (
            config.classifier_dropout_prob
            if config.classifier_dropout_prob is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[SequenceClassifierOutput, tuple]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Infer the problem type on first use from num_labels and the label dtype, and store it on the config
            # so later calls reuse the same choice.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
|
@auto_docstring
class AlbertForTokenClassification(AlbertPreTrainedModel):
    def __init__(self, config: AlbertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Per-token classification needs no pooler.
        self.albert = AlbertModel(config, add_pooling_layer=False)
        # Prefer the classifier-specific dropout; use the generic hidden dropout when it is unset.
        dropout_prob = config.classifier_dropout_prob
        if dropout_prob is None:
            dropout_prob = config.hidden_dropout_prob
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[TokenClassifierOutput, tuple]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Per-token labels for the token classification loss, each in `[0, ..., config.num_labels - 1]`.
        """
        use_return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        base_outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=use_return_dict,
        )

        token_states = self.dropout(base_outputs[0])
        logits = self.classifier(token_states)

        loss = (
            CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1)) if labels is not None else None
        )

        if use_return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=base_outputs.hidden_states,
                attentions=base_outputs.attentions,
            )

        tuple_output = (logits,) + base_outputs[2:]
        return tuple_output if loss is None else (loss,) + tuple_output
|
|
|
|
|
|
@auto_docstring
class AlbertForQuestionAnswering(AlbertPreTrainedModel):
    def __init__(self, config: AlbertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.albert = AlbertModel(config, add_pooling_layer=False)
        # Per-token span logits; `forward` splits the last dim into exactly two (start, end) channels, so this
        # assumes config.num_labels == 2.
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[QuestionAnsweringModelOutput, tuple]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits: torch.Tensor = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
|
@auto_docstring
class AlbertForMultipleChoice(AlbertPreTrainedModel):
    def __init__(self, config: AlbertConfig):
        super().__init__(config)

        self.albert = AlbertModel(config)
        # Fall back to `hidden_dropout_prob` when `classifier_dropout_prob` is unset, as AlbertForTokenClassification
        # does; `nn.Dropout(None)` would raise otherwise.
        classifier_dropout_prob = (
            config.classifier_dropout_prob
            if config.classifier_dropout_prob is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout_prob)
        # One score per (example, choice) pair; scores are regrouped per example in `forward`.
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[MultipleChoiceModelOutput, tuple]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
            [`PreTrainedTokenizer.encode`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where *num_choices* is the size of the second dimension of the input tensors. (see
            *input_ids* above)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if input_ids is None and inputs_embeds is None:
            # Fail with a clear message instead of an AttributeError on `None.shape` below.
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Fold the choice dimension into the batch dimension so the base model sees ordinary 2D/3D inputs.
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )
        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits: torch.Tensor = self.classifier(pooled_output)
        # Regroup the per-(example, choice) scores into one row of `num_choices` scores per example.
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
|
# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    "load_tf_weights_in_albert",
    "AlbertPreTrainedModel",
    "AlbertModel",
    "AlbertForPreTraining",
    "AlbertForMaskedLM",
    "AlbertForSequenceClassification",
    "AlbertForTokenClassification",
    "AlbertForQuestionAnswering",
    "AlbertForMultipleChoice",
]
|