# coding=utf-8
# Copyright 2021 ASAPP Inc. and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch SEW model."""

import math
import warnings
from typing import Optional, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring
from ..wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2Attention,
    Wav2Vec2EncoderLayer,
    Wav2Vec2FeatureEncoder,
    Wav2Vec2FeedForward,
    Wav2Vec2ForCTC,
    Wav2Vec2ForSequenceClassification,
    Wav2Vec2GroupNormConvLayer,
    Wav2Vec2LayerNormConvLayer,
    Wav2Vec2NoLayerNormConvLayer,
    Wav2Vec2SamePadLayer,
    _compute_mask_indices,
)
from .configuration_sew import SEWConfig


_HIDDEN_STATES_START_POSITION = 1


class SEWNoLayerNormConvLayer(Wav2Vec2NoLayerNormConvLayer):
    pass


class SEWLayerNormConvLayer(Wav2Vec2LayerNormConvLayer):
    pass


class SEWGroupNormConvLayer(Wav2Vec2GroupNormConvLayer):
    pass


class SEWPositionalConvEmbedding(nn.Module):
    """Convolutional positional embedding used by `SEWEncoder`.

    A grouped, weight-normalised `nn.Conv1d` over the hidden channels whose stride is
    `config.squeeze_factor`, so it downsamples the time axis by the same factor as the
    encoder's average pooling. Its output is added to the pooled hidden states.
    """

    def __init__(self, config):
        super().__init__()
        self.conv = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size=config.num_conv_pos_embeddings,
            padding=config.num_conv_pos_embeddings // 2,
            groups=config.num_conv_pos_embedding_groups,
            stride=config.squeeze_factor,
        )

        # Prefer the parametrization-based weight norm when this torch version provides it.
        weight_norm = nn.utils.weight_norm
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm

        if is_deepspeed_zero3_enabled():
            import deepspeed

            # Under ZeRO-3 the weight is partitioned; gather it before re-parametrizing.
            with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
                self.conv = weight_norm(self.conv, name="weight", dim=2)
            if hasattr(self.conv, "parametrizations"):
                weight_g = self.conv.parametrizations.weight.original0
                weight_v = self.conv.parametrizations.weight.original1
            else:
                weight_g = self.conv.weight_g
                weight_v = self.conv.weight_v
            deepspeed.zero.register_external_parameter(self, weight_v)
            deepspeed.zero.register_external_parameter(self, weight_g)
        else:
            self.conv = weight_norm(self.conv, name="weight", dim=2)

        self.padding = SEWSamePadLayer(config.num_conv_pos_embeddings)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        """Apply conv -> same-padding trim -> activation.

        `hidden_states` is channels-first, `(batch, hidden_size, time)`, as `SEWEncoder`
        transposes before calling this module.
        """
        hidden_states = self.conv(hidden_states)
        hidden_states = self.padding(hidden_states)
        hidden_states = self.activation(hidden_states)

        return hidden_states


class SEWSamePadLayer(Wav2Vec2SamePadLayer):
    pass


class SEWUpsampling(nn.Module):
    """Undo the encoder's temporal squeeze.

    Projects each frame to `hidden_size * squeeze_factor` channels, then (when
    `squeeze_factor > 1`) reshapes the extra channels into extra time steps, so the
    sequence length is multiplied by `squeeze_factor`.
    """

    def __init__(self, config):
        super().__init__()
        self.projection = nn.Linear(config.hidden_size, config.hidden_size * config.squeeze_factor)
        self.activation = ACT2FN[config.feat_extract_activation]
        self.squeeze_factor = config.squeeze_factor

    def forward(self, hidden_states):
        hidden_states = self.projection(hidden_states)
        hidden_states = self.activation(hidden_states)

        if self.squeeze_factor > 1:
            # transform embedding channels to sequence length
            bsz, src_len, src_embed_dim = hidden_states.size()
            tgt_len = src_len * self.squeeze_factor
            tgt_embed_dim = src_embed_dim // self.squeeze_factor
            hidden_states = hidden_states.reshape(bsz, src_len, self.squeeze_factor, tgt_embed_dim)
            hidden_states = hidden_states.reshape(bsz, tgt_len, tgt_embed_dim)

        return hidden_states


class SEWFeatureEncoder(Wav2Vec2FeatureEncoder):
    pass


class SEWFeatureExtractor(SEWFeatureEncoder):
    """Deprecated alias of `SEWFeatureEncoder`; emits a `FutureWarning` on construction."""

    def __init__(self, config):
        super().__init__(config)
        warnings.warn(
            f"The class `{self.__class__.__name__}` has been deprecated "
            "and will be removed in Transformers v5. "
            f"Use `{self.__class__.__bases__[0].__name__}` instead.",
            FutureWarning,
        )


class SEWAttention(Wav2Vec2Attention):
    pass


class SEWFeedForward(Wav2Vec2FeedForward):
    pass


class SEWEncoderLayer(Wav2Vec2EncoderLayer):
    pass


class SEWEncoder(nn.Module):
    """Squeezed transformer encoder.

    Pipeline: average-pool the time axis by `squeeze_factor`, add the strided positional
    conv embedding, LayerNorm + dropout, run the transformer layers (with LayerDrop), then
    upsample back by `squeeze_factor` and right-pad with zeros to the input length.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = SEWPositionalConvEmbedding(config)
        self.pool = nn.AvgPool1d(config.squeeze_factor, config.squeeze_factor)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList([SEWEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.upsample = SEWUpsampling(config)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if attention_mask is not None:
            # Boolean-indexing below requires a bool mask (as produced by
            # SEWPreTrainedModel._get_feature_vector_attention_mask).
            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            if self.config._attn_implementation == "flash_attention_2":
                # make sure padded tokens output 0
                hidden_states[~expand_attention_mask] = 0.0
                # 2d mask is passed through the layers
                # NOTE(review): this mask keeps the pre-pooling length, while the layers run on the
                # pooled sequence (length // squeeze_factor); confirm this path handles the mismatch.
                attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
            else:
                # make sure padded tokens output 0
                hidden_states[~expand_attention_mask] = 0.0

                input_lengths = (attention_mask.long()).sum(-1)
                # apply pooling formula to get real output_lengths
                output_lengths = input_lengths // self.config.squeeze_factor
                max_encoder_length = hidden_states.shape[1] // self.config.squeeze_factor
                attention_ids = (
                    torch.arange(0, max_encoder_length, device=output_lengths.device)
                    .view(1, -1)
                    .expand(output_lengths.shape[0], -1)
                )
                attention_mask = (attention_ids < output_lengths.view(-1, 1)).long()

                # extend attention_mask to an additive 4D mask: 0 where attended, dtype-min elsewhere
                attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
                attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
                attention_mask = attention_mask.expand(
                    attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
                )

        # Remember the un-squeezed length so the upsampled output can be padded back to it.
        n_input_timesteps = hidden_states.shape[1]

        hidden_states = hidden_states.transpose(1, 2)
        position_embeddings = self.pos_conv_embed(hidden_states)
        pooled_hidden_states = self.pool(hidden_states)
        # Conv and pooling may differ by a frame; trim both to the shorter one.
        min_length = min(position_embeddings.size(-1), pooled_hidden_states.size(-1))
        hidden_states = pooled_hidden_states[..., :min_length] + position_embeddings[..., :min_length]
        hidden_states = hidden_states.transpose(1, 2)

        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            dropout_probability = torch.rand([])

            skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
            if not skip_the_layer or synced_gpus:
                # under fsdp or deepspeed zero3 all gpus must run in sync
                layer_outputs = layer(
                    hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
                )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                layer_outputs = (None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.upsample(hidden_states)
        if hidden_states.shape[1] < n_input_timesteps:
            hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, n_input_timesteps - hidden_states.shape[1]))

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


@auto_docstring
class SEWPreTrainedModel(PreTrainedModel):
    # Shared weight initialisation and length/mask helpers for all SEW heads.
    config: SEWConfig
    base_model_prefix = "sew"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = False  # needs a proper look into the mask creation

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, SEWPositionalConvEmbedding):
            nn.init.normal_(
                module.conv.weight,
                mean=0,
                std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
            )
            nn.init.constant_(module.conv.bias, 0)
        elif isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            if is_deepspeed_zero3_enabled():
                import deepspeed

                if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
                    with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
                        nn.init.kaiming_normal_(module.weight.data)
                else:
                    with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
                        nn.init.kaiming_normal_(module.weight.data)
            else:
                nn.init.kaiming_normal_(module.weight.data)

        if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
            module.bias.data.zero_()

    def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        return input_lengths

    def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
        """Downsample a sample-level attention mask to the feature-frame resolution.

        Returns a bool mask of shape `(batch_size, feature_vector_length)` that is True for
        every frame before each sample's conv-feature output length.
        """
        output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
        batch_size = attention_mask.shape[0]

        attention_mask = torch.zeros(
            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )
        # these two operations makes sure that all values before the output lengths idxs are attended to
        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
        return attention_mask


@auto_docstring
class SEWModel(SEWPreTrainedModel):
    # Bare SEW model: conv feature encoder -> LayerNorm (+ optional projection) -> SEWEncoder.
    def __init__(self, config: SEWConfig):
        super().__init__(config)
        self.config = config
        self.feature_extractor = SEWFeatureEncoder(config)
        self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)

        # Only project when the conv feature width differs from the transformer width.
        self.project_features = config.conv_dim[-1] != config.hidden_size
        if self.project_features:
            self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
        self.feature_dropout = nn.Dropout(config.feat_proj_dropout)

        # Learned vector written into time-masked frames by SpecAugment.
        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())

        self.encoder = SEWEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
    def _mask_hidden_states(
        self,
        hidden_states: torch.FloatTensor,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://huggingface.co/papers/1904.08779).
        """

        # `config.apply_spec_augment` can set masking to False
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states

        # generate indices & apply SpecAugment along time axis
        batch_size, sequence_length, hidden_size = hidden_states.size()

        if mask_time_indices is not None:
            # apply SpecAugment along time axis with given mask_time_indices
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
        elif self.config.mask_time_prob > 0 and self.training:
            mask_time_indices = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        if self.config.mask_feature_prob > 0 and self.training:
            # generate indices & apply SpecAugment along feature axis
            mask_feature_indices = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
            mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
            hidden_states[mask_feature_indices] = 0

        return hidden_states

    @auto_docstring
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        r"""
        mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
            masked extracted features in *config.proj_codevector_dim* space.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)
        extract_features = self.layer_norm(extract_features)

        if self.project_features:
            extract_features = self.feature_projection(extract_features)
        hidden_states = self.feature_dropout(extract_features)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)

        # Forward the frame-level mask so SpecAugment spans are sampled within each sample's
        # real (unpadded) length, matching Wav2Vec2Model.
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class SEWForCTC(Wav2Vec2ForCTC):
    pass


class SEWForSequenceClassification(Wav2Vec2ForSequenceClassification):
    pass


__all__ = ["SEWForCTC", "SEWForSequenceClassification", "SEWModel", "SEWPreTrainedModel"]