# coding=utf-8
# Copyright 2021 ASAPP Inc. and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch SEW model."""
import math
import warnings
from typing import Optional, Union
import torch
import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring
from ..wav2vec2.modeling_wav2vec2 import (
Wav2Vec2Attention,
Wav2Vec2EncoderLayer,
Wav2Vec2FeatureEncoder,
Wav2Vec2FeedForward,
Wav2Vec2ForCTC,
Wav2Vec2ForSequenceClassification,
Wav2Vec2GroupNormConvLayer,
Wav2Vec2LayerNormConvLayer,
Wav2Vec2NoLayerNormConvLayer,
Wav2Vec2SamePadLayer,
_compute_mask_indices,
)
from .configuration_sew import SEWConfig
_HIDDEN_STATES_START_POSITION = 1
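# When `return_dict=False`, `SEWModel` returns `(last_hidden_state, hidden_states, attentions)`, so the per-layer
# hidden states start at index 1 of the tuple; downstream heads (e.g. the sequence-classification head inherited
# from Wav2Vec2) use this offset when locating the hidden states for their weighted layer sum.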
class SEWNoLayerNormConvLayer(Wav2Vec2NoLayerNormConvLayer):
pass
class SEWLayerNormConvLayer(Wav2Vec2LayerNormConvLayer):
pass
class SEWGroupNormConvLayer(Wav2Vec2GroupNormConvLayer):
pass
class SEWPositionalConvEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
self.conv = nn.Conv1d(
config.hidden_size,
config.hidden_size,
kernel_size=config.num_conv_pos_embeddings,
padding=config.num_conv_pos_embeddings // 2,
groups=config.num_conv_pos_embedding_groups,
stride=config.squeeze_factor,
)
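        # Unlike Wav2Vec2's positional convolution (stride 1), this one uses stride=config.squeeze_factor, so the
        # positional embeddings are produced directly at the pooled ("squeezed") frame rate used inside SEWEncoder.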
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = weight_norm(self.conv, name="weight", dim=2)
if hasattr(self.conv, "parametrizations"):
weight_g = self.conv.parametrizations.weight.original0
weight_v = self.conv.parametrizations.weight.original1
else:
weight_g = self.conv.weight_g
weight_v = self.conv.weight_v
deepspeed.zero.register_external_parameter(self, weight_v)
deepspeed.zero.register_external_parameter(self, weight_g)
else:
self.conv = weight_norm(self.conv, name="weight", dim=2)
self.padding = SEWSamePadLayer(config.num_conv_pos_embeddings)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.padding(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class SEWSamePadLayer(Wav2Vec2SamePadLayer):
pass
class SEWUpsampling(nn.Module):
def __init__(self, config):
super().__init__()
self.projection = nn.Linear(config.hidden_size, config.hidden_size * config.squeeze_factor)
self.activation = ACT2FN[config.feat_extract_activation]
self.squeeze_factor = config.squeeze_factor
def forward(self, hidden_states):
hidden_states = self.projection(hidden_states)
hidden_states = self.activation(hidden_states)
if self.squeeze_factor > 1:
# transform embedding channels to sequence length
bsz, src_len, src_embed_dim = hidden_states.size()
tgt_len = src_len * self.squeeze_factor
tgt_embed_dim = src_embed_dim // self.squeeze_factor
hidden_states = hidden_states.reshape(bsz, src_len, self.squeeze_factor, tgt_embed_dim)
hidden_states = hidden_states.reshape(bsz, tgt_len, tgt_embed_dim)
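            # Shape sketch (illustrative numbers): with hidden_size=768 and squeeze_factor=2, the projected
            # (bsz, src_len, 1536) tensor becomes (bsz, src_len, 2, 768) and finally (bsz, 2 * src_len, 768).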
return hidden_states
class SEWFeatureEncoder(Wav2Vec2FeatureEncoder):
pass
class SEWFeatureExtractor(SEWFeatureEncoder):
def __init__(self, config):
super().__init__(config)
        warnings.warn(
            f"The class `{self.__class__.__name__}` has been deprecated "
            "and will be removed in Transformers v5. "
            f"Use `{self.__class__.__bases__[0].__name__}` instead.",
            FutureWarning,
        )
class SEWAttention(Wav2Vec2Attention):
pass
class SEWFeedForward(Wav2Vec2FeedForward):
pass
class SEWEncoderLayer(Wav2Vec2EncoderLayer):
pass
class SEWEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.pos_conv_embed = SEWPositionalConvEmbedding(config)
self.pool = nn.AvgPool1d(config.squeeze_factor, config.squeeze_factor)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.ModuleList([SEWEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.upsample = SEWUpsampling(config)
self.gradient_checkpointing = False
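    # SEW's "squeeze" design: hidden states are average-pooled by `squeeze_factor` before the transformer layers
    # and upsampled back afterwards, so the attention layers run on a `squeeze_factor`-times shorter sequence.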
def forward(
self,
hidden_states,
attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
if attention_mask is not None:
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
if self.config._attn_implementation == "flash_attention_2":
# make sure padded tokens output 0
hidden_states[~expand_attention_mask] = 0.0
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
# make sure padded tokens output 0
hidden_states[~expand_attention_mask] = 0.0
input_lengths = (attention_mask.long()).sum(-1)
# apply pooling formula to get real output_lengths
output_lengths = input_lengths // self.config.squeeze_factor
max_encoder_length = hidden_states.shape[1] // self.config.squeeze_factor
attention_ids = (
torch.arange(0, max_encoder_length, device=output_lengths.device)
.view(1, -1)
.expand(output_lengths.shape[0], -1)
)
attention_mask = (attention_ids < output_lengths.view(-1, 1)).long()
# extend attention_mask
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
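                # Illustrative example assuming squeeze_factor=2: 100 valid frames out of 124 give
                # output_lengths=50 over max_encoder_length=62 pooled steps; the additive mask is 0.0 where
                # attention is allowed and torch.finfo(dtype).min elsewhere, expanded to (batch, 1, tgt_len, tgt_len).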
n_input_timesteps = hidden_states.shape[1]
hidden_states = hidden_states.transpose(1, 2)
position_embeddings = self.pos_conv_embed(hidden_states)
pooled_hidden_states = self.pool(hidden_states)
min_length = min(position_embeddings.size(-1), pooled_hidden_states.size(-1))
hidden_states = pooled_hidden_states[..., :min_length] + position_embeddings[..., :min_length]
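        # The strided positional convolution and the average pooling can differ in output length by one frame,
        # so both are truncated to the shorter length before being added together.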
hidden_states = hidden_states.transpose(1, 2)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
for layer in self.layers:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
dropout_probability = torch.rand([])
            skip_the_layer = self.training and (dropout_probability < self.config.layerdrop)
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
layer_outputs = layer(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
hidden_states = layer_outputs[0]
if skip_the_layer:
layer_outputs = (None, None)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
hidden_states = self.upsample(hidden_states)
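        # Upsampling restores roughly `squeeze_factor` times the pooled length; if that still falls short of the
        # original number of frames (because pooling rounded down), the time axis is right-padded with zeros below.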
if hidden_states.shape[1] < n_input_timesteps:
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, n_input_timesteps - hidden_states.shape[1]))
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
@auto_docstring
class SEWPreTrainedModel(PreTrainedModel):
config: SEWConfig
base_model_prefix = "sew"
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = False # needs a proper look into the mask creation
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, SEWPositionalConvEmbedding):
nn.init.normal_(
module.conv.weight,
mean=0,
std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
)
nn.init.constant_(module.conv.bias, 0)
elif isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Conv1d):
if is_deepspeed_zero3_enabled():
import deepspeed
if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
nn.init.kaiming_normal_(module.weight.data)
else:
with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
nn.init.kaiming_normal_(module.weight.data)
else:
nn.init.kaiming_normal_(module.weight.data)
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
module.bias.data.zero_()
def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
"""
Computes the output length of the convolutional layers
"""
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
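            # e.g. (illustrative values) input_length=16000, kernel_size=10, stride=5
            # -> (16000 - 10) // 5 + 1 = 3199 output frames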
return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
return input_lengths
def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
batch_size = attention_mask.shape[0]
attention_mask = torch.zeros(
(batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
)
        # these two operations make sure that all values before the output length indices are attended to
attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
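        # Tiny example: feature_vector_length=5 and an output length of 3 give [0, 0, 1, 0, 0] after the scatter
        # above, and [1, 1, 1, 0, 0] after flip -> cumsum -> flip, i.e. everything up to the last valid frame.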
return attention_mask
@auto_docstring
class SEWModel(SEWPreTrainedModel):
def __init__(self, config: SEWConfig):
super().__init__(config)
self.config = config
self.feature_extractor = SEWFeatureEncoder(config)
self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
self.project_features = config.conv_dim[-1] != config.hidden_size
if self.project_features:
self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
self.feature_dropout = nn.Dropout(config.feat_proj_dropout)
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
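            # Learned vector used to replace time steps selected by SpecAugment in `_mask_hidden_states`.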
self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
self.encoder = SEWEncoder(config)
# Initialize weights and apply final processing
self.post_init()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
def _mask_hidden_states(
self,
hidden_states: torch.FloatTensor,
mask_time_indices: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
"""
Masks extracted features along time axis and/or along feature axis according to
[SpecAugment](https://huggingface.co/papers/1904.08779).
"""
# `config.apply_spec_augment` can set masking to False
if not getattr(self.config, "apply_spec_augment", True):
return hidden_states
# generate indices & apply SpecAugment along time axis
batch_size, sequence_length, hidden_size = hidden_states.size()
if mask_time_indices is not None:
# apply SpecAugment along time axis with given mask_time_indices
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
elif self.config.mask_time_prob > 0 and self.training:
mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length),
mask_prob=self.config.mask_time_prob,
mask_length=self.config.mask_time_length,
attention_mask=attention_mask,
min_masks=self.config.mask_time_min_masks,
)
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
if self.config.mask_feature_prob > 0 and self.training:
# generate indices & apply SpecAugment along feature axis
mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size),
mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length,
min_masks=self.config.mask_feature_min_masks,
)
mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
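            # expanding over the time axis zeroes the selected feature channels at every timestep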
hidden_states[mask_feature_indices] = 0
return hidden_states
@auto_docstring
def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
mask_time_indices: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, BaseModelOutput]:
r"""
        mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices to mask extracted features along the time axis (SpecAugment). During training, masked
            positions are replaced by the model's learned `masked_spec_embed` vector.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
extract_features = self.feature_extractor(input_values)
extract_features = extract_features.transpose(1, 2)
extract_features = self.layer_norm(extract_features)
if self.project_features:
extract_features = self.feature_projection(extract_features)
hidden_states = self.feature_dropout(extract_features)
if attention_mask is not None:
# compute reduced attention_mask corresponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = encoder_outputs[0]
if not return_dict:
return (hidden_states,) + encoder_outputs[1:]
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
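# A minimal usage sketch for `SEWModel` (the checkpoint name and `raw_speech` array are illustrative,
# not part of this module):
#
#     from transformers import AutoFeatureExtractor, SEWModel
#     import torch
#
#     feature_extractor = AutoFeatureExtractor.from_pretrained("asapp/sew-tiny-100k")
#     model = SEWModel.from_pretrained("asapp/sew-tiny-100k")
#     inputs = feature_extractor(raw_speech, sampling_rate=16000, return_tensors="pt")
#     with torch.no_grad():
#         outputs = model(**inputs)
#     outputs.last_hidden_state  # (batch, frames, config.hidden_size)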
class SEWForCTC(Wav2Vec2ForCTC):
pass
class SEWForSequenceClassification(Wav2Vec2ForSequenceClassification):
pass
__all__ = ["SEWForCTC", "SEWForSequenceClassification", "SEWModel", "SEWPreTrainedModel"]