import math
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, Wav2Vec2BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from ..wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2FeatureProjection,
    Wav2Vec2FeedForward,
    Wav2Vec2ForAudioFrameClassification,
    Wav2Vec2ForCTC,
    Wav2Vec2ForSequenceClassification,
    Wav2Vec2ForXVector,
    Wav2Vec2Model,
    Wav2Vec2PositionalConvEmbedding,
    Wav2Vec2PreTrainedModel,
)
from .configuration_wavlm import WavLMConfig


logger = logging.get_logger(__name__)


class WavLMPositionalConvEmbedding(Wav2Vec2PositionalConvEmbedding):
    pass


class WavLMFeatureProjection(Wav2Vec2FeatureProjection):
    pass


class WavLMAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        num_buckets: int = 320,
        max_distance: int = 800,
        has_relative_position_bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5

        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.num_buckets = num_buckets
        self.max_distance = max_distance

        self.gru_rel_pos_const = nn.Parameter(torch.ones(1, self.num_heads, 1, 1))
        self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8)

        if has_relative_position_bias:
            self.rel_attn_embed = nn.Embedding(self.num_buckets, self.num_heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_bias: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        index=0,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Attention layer with relative attention"""
        bsz, tgt_len, _ = hidden_states.size()

        # first pass of attention layer creates position bias
        if position_bias is None:
            position_bias = self.compute_bias(tgt_len, tgt_len)
            position_bias = (
                position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, tgt_len)
            )

        # Compute relative position bias:
        # 1) reshape hidden_states to (batch, num_heads, seq_len, head_dim)
        gated_hidden_states = hidden_states.view(hidden_states.shape[:-1] + (self.num_heads, -1))
        gated_hidden_states = gated_hidden_states.permute(0, 2, 1, 3)

        # 2) project hidden states
        relative_position_proj = self.gru_rel_pos_linear(gated_hidden_states)
        relative_position_proj = relative_position_proj.view(gated_hidden_states.shape[:-1] + (2, 4)).sum(-1)

        # 3) compute gate for position bias from projected hidden states
        gate_a, gate_b = torch.sigmoid(relative_position_proj).chunk(2, dim=-1)
        gate_output = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0

        # 4) apply gate to position bias to compute gated position_bias
        gated_position_bias = gate_output.view(bsz * self.num_heads, -1, 1) * position_bias
        gated_position_bias = gated_position_bias.view((-1, tgt_len, tgt_len))

        attn_output, attn_weights = self.torch_multi_head_self_attention(
            hidden_states, attention_mask, gated_position_bias, output_attentions
        )

        return attn_output, attn_weights, position_bias
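    # Shape sketch of the gating above (derived from the code, for reference): `gate_output` has
    # shape (batch, num_heads, seq_len, 1) and is reshaped to (batch * num_heads, seq_len, 1), so it
    # scales the rows of `position_bias` (batch * num_heads, seq_len, seq_len) per query position
    # while broadcasting over key positions.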

    def torch_multi_head_self_attention(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Union[torch.LongTensor, torch.BoolTensor],
        gated_position_bias: torch.FloatTensor,
        output_attentions: bool,
    ) -> (torch.FloatTensor, torch.FloatTensor):
        """simple wrapper around torch's multi_head_attention_forward function"""
        # self-attention assumes q = k = v
        query = key = value = hidden_states.transpose(0, 1)
        key_padding_mask = attention_mask.ne(1) if attention_mask is not None else None

        # disable bias and add_zero_attn
        bias_k = bias_v = None
        add_zero_attn = False

        # PyTorch 1.3.0 has F.multi_head_attention_forward defined
        # so no problem with backwards compatibility
        attn_output, attn_weights = F.multi_head_attention_forward(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            torch.empty([0]),
            torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
            bias_k,
            bias_v,
            add_zero_attn,
            self.dropout,
            self.out_proj.weight,
            self.out_proj.bias,
            self.training,
            key_padding_mask,
            output_attentions,
            gated_position_bias,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
        )

        # [Seq_Len, Batch Size, ...] -> [Batch Size, Seq_Len, ...]
        attn_output = attn_output.transpose(0, 1)

        if attn_weights is not None:
            # IMPORTANT: Attention weights are averaged weights
            # here which should not be the case. This is an open issue
            # on PyTorch: https://github.com/pytorch/pytorch/issues/32590
            attn_weights = attn_weights[:, None].broadcast_to(
                attn_weights.shape[:1] + (self.num_heads,) + attn_weights.shape[1:]
            )

        return attn_output, attn_weights

    def compute_bias(self, query_length: int, key_length: int) -> torch.FloatTensor:
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position
        relative_position_bucket = self._relative_positions_bucket(relative_position)
        relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device)
        values = self.rel_attn_embed(relative_position_bucket)
        values = values.permute([2, 0, 1])
        return values

    def _relative_positions_bucket(self, relative_positions: torch.FloatTensor) -> torch.FloatTensor:
        num_buckets = self.num_buckets // 2

        relative_buckets = (relative_positions > 0).to(torch.long) * num_buckets
        relative_positions = torch.abs(relative_positions)

        max_exact = num_buckets // 2
        is_small = relative_positions < max_exact

        relative_positions_if_large = torch.log(relative_positions.float() / max_exact)
        relative_positions_if_large = relative_positions_if_large / math.log(self.max_distance / max_exact)
        relative_positions_if_large = relative_positions_if_large * (num_buckets - max_exact)
        relative_position_if_large = (max_exact + relative_positions_if_large).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_positions, relative_position_if_large)
        return relative_buckets
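    # Worked example for the bucketing above (a sketch with the defaults num_buckets=320 and
    # max_distance=800, i.e. 160 buckets per direction with max_exact=80):
    #   relative position   -5 -> bucket   5  (negative offsets use buckets 0..159; exact below 80)
    #   relative position   +5 -> bucket 165  (positive offsets are shifted by 160)
    #   relative position +200 -> bucket 271  (log-spaced: 160 + 80 + floor(log(200/80) / log(800/80) * 80))
    #   relative position +800 or more -> bucket 319 (offsets beyond max_distance share the last bucket)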

class WavLMFeedForward(Wav2Vec2FeedForward):
    pass


class WavLMEncoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
        super().__init__()
        self.attention = WavLMAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            num_buckets=config.num_buckets,
            max_distance=config.max_bucket_distance,
            has_relative_position_bias=has_relative_position_bias,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = WavLMFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
        attn_residual = hidden_states
        hidden_states, attn_weights, position_bias = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            index=index,
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states

        hidden_states = self.layer_norm(hidden_states)

        hidden_states = hidden_states + self.feed_forward(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states, position_bias)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class WavLMEncoderLayerStableLayerNorm(GradientCheckpointingLayer):
    def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
        super().__init__()
        self.attention = WavLMAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            num_buckets=config.num_buckets,
            max_distance=config.max_bucket_distance,
            has_relative_position_bias=has_relative_position_bias,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = WavLMFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False):
        attn_residual = hidden_states
        hidden_states = self.layer_norm(hidden_states)
        hidden_states, attn_weights, position_bias = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states
        hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))

        outputs = (hidden_states, position_bias)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class WavLMEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList(
            [WavLMEncoderLayer(config, has_relative_position_bias=(i == 0)) for i in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if attention_mask is not None:
            # make sure padded tokens output 0
            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            hidden_states[~expand_attention_mask] = 0

        position_embeddings = self.pos_conv_embed(hidden_states)
        hidden_states = hidden_states + position_embeddings
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
        position_bias = None

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            dropout_probability = torch.rand([])

            skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
            if not skip_the_layer or synced_gpus:
                # under fsdp or deepspeed zero3 all gpus must run in sync
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_bias=position_bias,
                    output_attentions=output_attentions,
                    index=i,
                )

                hidden_states, position_bias = layer_outputs[:2]

            if skip_the_layer:
                layer_outputs = (None, None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
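# Minimal shape sketch for driving the encoder directly (hypothetical tensors, not executed here):
#
#   config = WavLMConfig()
#   encoder = WavLMEncoder(config)
#   features = torch.randn(2, 100, config.hidden_size)  # (batch, frames, hidden_size)
#   outputs = encoder(features, output_attentions=True)
#   outputs.last_hidden_state.shape  # -> (2, 100, config.hidden_size)
#
# Only the first layer owns `rel_attn_embed`: it returns `position_bias`, which the loop above then
# passes to every following layer, so the bias is computed only once per forward pass.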

class WavLMEncoderStableLayerNorm(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList(
            [
                WavLMEncoderLayerStableLayerNorm(config, has_relative_position_bias=(i == 0))
                for i in range(config.num_hidden_layers)
            ]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if attention_mask is not None:
            # make sure padded tokens are not attended to
            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            hidden_states[~expand_attention_mask] = 0

        position_embeddings = self.pos_conv_embed(hidden_states)
        hidden_states = hidden_states + position_embeddings
        hidden_states = self.dropout(hidden_states)

        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
        position_bias = None

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            dropout_probability = torch.rand([])

            skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
            if not skip_the_layer or synced_gpus:
                # under fsdp or deepspeed zero3 all gpus must run in sync
                # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    output_attentions=output_attentions,
                    position_bias=position_bias,
                )
                hidden_states, position_bias = layer_outputs[:2]

            if skip_the_layer:
                layer_outputs = (None, None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[2],)

        hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
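# Note: the "stable layer norm" encoder above applies LayerNorm before each attention/feed-forward
# block (pre-norm) and normalizes once more after the final layer, whereas WavLMEncoder and
# WavLMEncoderLayer normalize after each residual connection (post-norm). Which pair the model
# instantiates follows the wav2vec2-style `config.do_stable_layer_norm` flag.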

class WavLMGumbelVectorQuantizer(nn.Module):
    """
    Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
    GUMBEL-SOFTMAX](https://huggingface.co/papers/1611.01144) for more information.
    """

    def __init__(self, config):
        super().__init__()
        self.num_groups = config.num_codevector_groups
        self.num_vars = config.num_codevectors_per_group

        if config.codevector_dim % self.num_groups != 0:
            raise ValueError(
                f"`config.codevector_dim` {config.codevector_dim} must be divisible"
                f" by `config.num_codevector_groups` {self.num_groups} "
                "for concatenation."
            )

        # storage for codebook variables (codewords)
        self.codevectors = nn.Parameter(
            torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
        )
        self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)

        # can be decayed for training
        self.temperature = 2

    @staticmethod
    def _compute_perplexity(probs):
        marginal_probs = probs.mean(dim=0)
        perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
        return perplexity

    def forward(self, hidden_states):
        batch_size, sequence_length, hidden_size = hidden_states.shape

        # project to codevector dim
        hidden_states = self.weight_proj(hidden_states)
        hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)

        if self.training:
            # sample code vector probs via gumbel in differentiable way
            codevector_probs = nn.functional.gumbel_softmax(hidden_states.float(), tau=self.temperature, hard=True)
            codevector_probs = codevector_probs.type_as(hidden_states)

            # compute perplexity
            codevector_soft_dist = torch.softmax(
                hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
            )
            perplexity = self._compute_perplexity(codevector_soft_dist)
        else:
            # take argmax in non-differentiable way
            # compute hard codevector distribution (one hot)
            codevector_idx = hidden_states.argmax(dim=-1)
            codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
                -1, codevector_idx.view(-1, 1), 1.0
            )
            codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)

            perplexity = self._compute_perplexity(codevector_probs)

        codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
        # use probs to retrieve codevectors
        codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
        codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
        codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)

        return codevectors, perplexity
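# Output sketch (derived from the forward pass above): `codevectors` has shape
# (batch_size, sequence_length, codevector_dim), built by selecting one codeword of size
# codevector_dim // num_groups per group and concatenating the groups; `perplexity` is a scalar of
# at most num_groups * num_vars, with higher values meaning the codebook entries are used more evenly.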

class WavLMPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel):
    config: WavLMConfig
    base_model_prefix = "wavlm"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True
    _supports_flash_attn = False
    _supports_sdpa = False
    _supports_flex_attn = False

    def _init_weights(self, module):
        """Initialize the weights"""
        # gumbel softmax requires special init
        if isinstance(module, WavLMGumbelVectorQuantizer):
            module.weight_proj.weight.data.normal_(mean=0.0, std=1)
            module.weight_proj.bias.data.zero_()
            nn.init.uniform_(module.codevectors)
        elif isinstance(module, WavLMPositionalConvEmbedding):
            nn.init.normal_(
                module.conv.weight,
                mean=0,
                std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
            )
            nn.init.constant_(module.conv.bias, 0)
        elif isinstance(module, WavLMFeatureProjection):
            k = math.sqrt(1 / module.projection.in_features)
            nn.init.uniform_(module.projection.weight, a=-k, b=k)
            nn.init.uniform_(module.projection.bias, a=-k, b=k)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)

            if module.bias is not None:
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-k, b=k)

    def _get_adapters(self):
        raise AttributeError("Not needed for WavLM")

    def init_adapter_layers(self):
        raise AttributeError("Not needed for WavLM")

    def load_adapter(self):
        raise AttributeError("Not needed for WavLM")


WavLMBaseModelOutput = Wav2Vec2BaseModelOutput


class WavLMModel(Wav2Vec2Model):
    pass


class WavLMForCTC(Wav2Vec2ForCTC):
    pass


class WavLMForSequenceClassification(Wav2Vec2ForSequenceClassification):
    pass


class WavLMForAudioFrameClassification(Wav2Vec2ForAudioFrameClassification):
    pass


class WavLMForXVector(Wav2Vec2ForXVector):
    pass


__all__ = [
    "WavLMForAudioFrameClassification",
    "WavLMForCTC",
    "WavLMForSequenceClassification",
    "WavLMForXVector",
    "WavLMModel",
    "WavLMPreTrainedModel",
]
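# Minimal usage sketch (assumes the `microsoft/wavlm-base` checkpoint on the Hub and that
# `raw_speech` is a 1-D float array of 16 kHz audio; not executed at import time):
#
#   from transformers import AutoFeatureExtractor, WavLMModel
#
#   feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/wavlm-base")
#   model = WavLMModel.from_pretrained("microsoft/wavlm-base")
#   inputs = feature_extractor(raw_speech, sampling_rate=16000, return_tensors="pt")
#   outputs = model(**inputs)  # outputs.last_hidden_state: (batch, frames, hidden_size)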