import math
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, Wav2Vec2BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from ..wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2FeatureProjection,
    Wav2Vec2FeedForward,
    Wav2Vec2ForAudioFrameClassification,
    Wav2Vec2ForCTC,
    Wav2Vec2ForSequenceClassification,
    Wav2Vec2ForXVector,
    Wav2Vec2Model,
    Wav2Vec2PositionalConvEmbedding,
    Wav2Vec2PreTrainedModel,
)
from .configuration_wavlm import WavLMConfig


logger = logging.get_logger(__name__)

class WavLMPositionalConvEmbedding(Wav2Vec2PositionalConvEmbedding):
    pass


class WavLMFeatureProjection(Wav2Vec2FeatureProjection):
    pass

class WavLMAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        num_buckets: int = 320,
        max_distance: int = 800,
        has_relative_position_bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5

        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.num_buckets = num_buckets
        self.max_distance = max_distance

        self.gru_rel_pos_const = nn.Parameter(torch.ones(1, self.num_heads, 1, 1))
        self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8)

        if has_relative_position_bias:
            self.rel_attn_embed = nn.Embedding(self.num_buckets, self.num_heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_bias: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        index=0,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Attention layer with relative attention"""
        bsz, tgt_len, _ = hidden_states.size()

        # the first attention layer creates the position bias; later layers reuse it
        if position_bias is None:
            position_bias = self.compute_bias(tgt_len, tgt_len)
            position_bias = (
                position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, tgt_len)
            )

        # Compute gated relative position bias:
        # 1) reshape hidden_states to (batch, num_heads, seq_len, head_dim)
        gated_hidden_states = hidden_states.view(hidden_states.shape[:-1] + (self.num_heads, -1))
        gated_hidden_states = gated_hidden_states.permute(0, 2, 1, 3)

        # 2) project hidden states
        relative_position_proj = self.gru_rel_pos_linear(gated_hidden_states)
        relative_position_proj = relative_position_proj.view(gated_hidden_states.shape[:-1] + (2, 4)).sum(-1)

        # 3) compute gate for position bias from projected hidden states
        gate_a, gate_b = torch.sigmoid(relative_position_proj).chunk(2, dim=-1)
        gate_output = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0

        # 4) apply gate to position bias to compute gated position_bias
        gated_position_bias = gate_output.view(bsz * self.num_heads, -1, 1) * position_bias
        gated_position_bias = gated_position_bias.view((-1, tgt_len, tgt_len))

        attn_output, attn_weights = self.torch_multi_head_self_attention(
            hidden_states, attention_mask, gated_position_bias, output_attentions
        )

        return attn_output, attn_weights, position_bias

    def torch_multi_head_self_attention(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Union[torch.LongTensor, torch.BoolTensor],
        gated_position_bias: torch.FloatTensor,
        output_attentions: bool,
    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        """Simple wrapper around torch's multi_head_attention_forward function"""
        # self-attention assumes q = k = v
        query = key = value = hidden_states.transpose(0, 1)
        key_padding_mask = attention_mask.ne(1) if attention_mask is not None else None

        # disable bias_k/bias_v and add_zero_attn
        bias_k = bias_v = None
        add_zero_attn = False

        # F.multi_head_attention_forward has been available since PyTorch 1.3.0,
        # so there is no backwards-compatibility concern here
        attn_output, attn_weights = F.multi_head_attention_forward(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            torch.empty([0]),
            torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
            bias_k,
            bias_v,
            add_zero_attn,
            self.dropout,
            self.out_proj.weight,
            self.out_proj.bias,
            self.training,
            key_padding_mask,
            output_attentions,
            gated_position_bias,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
        )

        # [seq_len, batch_size, ...] -> [batch_size, seq_len, ...]
        attn_output = attn_output.transpose(0, 1)

        if attn_weights is not None:
            # IMPORTANT: the attention weights returned by PyTorch are averaged over heads,
            # which should not be the case here. This is an open issue on PyTorch:
            # https://github.com/pytorch/pytorch/issues/32590
            attn_weights = attn_weights[:, None].broadcast_to(
                attn_weights.shape[:1] + (self.num_heads,) + attn_weights.shape[1:]
            )

        return attn_output, attn_weights

    def compute_bias(self, query_length: int, key_length: int) -> torch.FloatTensor:
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position
        relative_position_bucket = self._relative_positions_bucket(relative_position)
        relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device)
        values = self.rel_attn_embed(relative_position_bucket)
        values = values.permute([2, 0, 1])
        return values

    def _relative_positions_bucket(self, relative_positions: torch.Tensor) -> torch.Tensor:
        num_buckets = self.num_buckets // 2

        relative_buckets = (relative_positions > 0).to(torch.long) * num_buckets
        relative_positions = torch.abs(relative_positions)

        max_exact = num_buckets // 2
        is_small = relative_positions < max_exact

        relative_positions_if_large = torch.log(relative_positions.float() / max_exact)
        relative_positions_if_large = relative_positions_if_large / math.log(self.max_distance / max_exact)
        relative_positions_if_large = relative_positions_if_large * (num_buckets - max_exact)
        relative_positions_if_large = (max_exact + relative_positions_if_large).to(torch.long)
        relative_positions_if_large = torch.min(
            relative_positions_if_large, torch.full_like(relative_positions_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_positions, relative_positions_if_large)
        return relative_buckets

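# The following is an illustrative, commented-out sketch (not part of the model) of how the
# relative-position bucketing and gated bias above behave in isolation; the tensor sizes and
# the `attn` instance are arbitrary assumptions for the example only.
#
#     attn = WavLMAttention(embed_dim=768, num_heads=12, num_buckets=320, max_distance=800)
#     # distances 0..max_exact-1 (here 80) get their own bucket; larger distances share
#     # logarithmically spaced buckets up to `max_distance`, with sign handled separately
#     buckets = attn._relative_positions_bucket(torch.arange(-5, 6)[None, :])
#     bias = attn.compute_bias(query_length=50, key_length=50)  # (num_heads, 50, 50)
#     hidden = torch.randn(2, 50, 768)  # (batch, seq_len, embed_dim), dummy input
#     out, weights, pos_bias = attn(hidden, output_attentions=True)
#     # pos_bias: (batch * num_heads, 50, 50), reused by later layers that are built
#     # with `has_relative_position_bias=False`
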
class WavLMFeedForward(Wav2Vec2FeedForward):
    pass

class WavLMEncoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
        super().__init__()
        self.attention = WavLMAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            num_buckets=config.num_buckets,
            max_distance=config.max_bucket_distance,
            has_relative_position_bias=has_relative_position_bias,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = WavLMFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
        attn_residual = hidden_states
        hidden_states, attn_weights, position_bias = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            index=index,
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states

        hidden_states = self.layer_norm(hidden_states)

        hidden_states = hidden_states + self.feed_forward(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states, position_bias)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs

class WavLMEncoderLayerStableLayerNorm(GradientCheckpointingLayer):
    def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
        super().__init__()
        self.attention = WavLMAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            num_buckets=config.num_buckets,
            max_distance=config.max_bucket_distance,
            has_relative_position_bias=has_relative_position_bias,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = WavLMFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False):
        attn_residual = hidden_states
        hidden_states = self.layer_norm(hidden_states)
        hidden_states, attn_weights, position_bias = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states
        hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))

        outputs = (hidden_states, position_bias)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs

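# Descriptive note (not executed): `WavLMEncoderLayer` applies `layer_norm` after the attention
# residual (post-norm), while `WavLMEncoderLayerStableLayerNorm` normalizes before attention and
# before the feed-forward block (pre-norm). Which variant is used is presumably selected by the
# config's `do_stable_layer_norm` flag, mirroring the Wav2Vec2 encoders this file inherits from.
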
class WavLMEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList(
            [WavLMEncoderLayer(config, has_relative_position_bias=(i == 0)) for i in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if attention_mask is not None:
            # make sure padded tokens output 0
            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            hidden_states[~expand_attention_mask] = 0

        position_embeddings = self.pos_conv_embed(hidden_states)
        hidden_states = hidden_states + position_embeddings
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
        position_bias = None

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            dropout_probability = torch.rand([])

            skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
            if not skip_the_layer or synced_gpus:
                # under fsdp or deepspeed zero3 all gpus must run in sync
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_bias=position_bias,
                    output_attentions=output_attentions,
                    index=i,
                )

                hidden_states, position_bias = layer_outputs[:2]

            if skip_the_layer:
                layer_outputs = (None, None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

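# Descriptive note (not executed): only `self.layers[0]` is built with
# `has_relative_position_bias=True`, so the bias it computes is threaded through all later
# layers via the `position_bias` argument, roughly as in this sketch (`encoder` stands for a
# WavLMEncoder instance; the loop mirrors the forward pass above without LayerDrop):
#
#     position_bias = None
#     for layer in encoder.layers:
#         hidden_states, position_bias = layer(hidden_states, position_bias=position_bias)[:2]
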
class WavLMEncoderStableLayerNorm(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList(
            [
                WavLMEncoderLayerStableLayerNorm(config, has_relative_position_bias=(i == 0))
                for i in range(config.num_hidden_layers)
            ]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if attention_mask is not None:
            # make sure padded tokens are not attended to
            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            hidden_states[~expand_attention_mask] = 0

        position_embeddings = self.pos_conv_embed(hidden_states)
        hidden_states = hidden_states + position_embeddings
        hidden_states = self.dropout(hidden_states)

        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
        position_bias = None

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            dropout_probability = torch.rand([])

            skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
            if not skip_the_layer or synced_gpus:
                # under fsdp or deepspeed zero3 all gpus must run in sync
                # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    output_attentions=output_attentions,
                    position_bias=position_bias,
                )
                hidden_states, position_bias = layer_outputs[:2]

            if skip_the_layer:
                layer_outputs = (None, None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[2],)

        hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
        )

class WavLMGumbelVectorQuantizer(nn.Module):
    """
    Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
    GUMBEL-SOFTMAX](https://huggingface.co/papers/1611.01144) for more information.
    """

    def __init__(self, config):
        super().__init__()
        self.num_groups = config.num_codevector_groups
        self.num_vars = config.num_codevectors_per_group

        if config.codevector_dim % self.num_groups != 0:
            raise ValueError(
                f"`config.codevector_dim` {config.codevector_dim} must be divisible"
                f" by `config.num_codevector_groups` {self.num_groups} "
                "for concatenation."
            )

        # storage for codebook variables (codewords)
        self.codevectors = nn.Parameter(
            torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
        )
        self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)

        # can be decayed for training
        self.temperature = 2

    @staticmethod
    def _compute_perplexity(probs):
        marginal_probs = probs.mean(dim=0)
        perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
        return perplexity

    def forward(self, hidden_states):
        batch_size, sequence_length, hidden_size = hidden_states.shape

        # project to codevector logits (num_groups * num_codevectors_per_group)
        hidden_states = self.weight_proj(hidden_states)
        hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)

        if self.training:
            # sample code vector probs via gumbel softmax in a differentiable way
            codevector_probs = nn.functional.gumbel_softmax(hidden_states.float(), tau=self.temperature, hard=True)
            codevector_probs = codevector_probs.type_as(hidden_states)

            # compute perplexity
            codevector_soft_dist = torch.softmax(
                hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
            )
            perplexity = self._compute_perplexity(codevector_soft_dist)
        else:
            # take argmax in non-differentiable way
            # compute hard codevector distribution (one hot)
            codevector_idx = hidden_states.argmax(dim=-1)
            codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
                -1, codevector_idx.view(-1, 1), 1.0
            )
            codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)

            perplexity = self._compute_perplexity(codevector_probs)

        codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
        # use probs to retrieve codevectors
        codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
        codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
        codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)

        return codevectors, perplexity

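# Illustrative, commented-out sketch (not part of the model): expected shapes for the quantizer
# above, assuming a config with num_codevector_groups=2, num_codevectors_per_group=320,
# codevector_dim=256 and conv_dim[-1]=512 (these values are assumptions for the example only).
#
#     quantizer = WavLMGumbelVectorQuantizer(config)
#     features = torch.randn(2, 100, 512)            # (batch, frames, conv_dim[-1])
#     codevectors, perplexity = quantizer(features)
#     # codevectors: (2, 100, 256) -- one concatenated codevector per frame
#     # perplexity: scalar measure of how evenly the codebook entries are used
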
class WavLMPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel):
    config: WavLMConfig
    base_model_prefix = "wavlm"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True
    _supports_flash_attn = False
    _supports_sdpa = False
    _supports_flex_attn = False

    def _init_weights(self, module):
        """Initialize the weights"""
        # gumbel softmax requires special init
        if isinstance(module, WavLMGumbelVectorQuantizer):
            module.weight_proj.weight.data.normal_(mean=0.0, std=1)
            module.weight_proj.bias.data.zero_()
            nn.init.uniform_(module.codevectors)
        elif isinstance(module, WavLMPositionalConvEmbedding):
            nn.init.normal_(
                module.conv.weight,
                mean=0,
                std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
            )
            nn.init.constant_(module.conv.bias, 0)
        elif isinstance(module, WavLMFeatureProjection):
            k = math.sqrt(1 / module.projection.in_features)
            nn.init.uniform_(module.projection.weight, a=-k, b=k)
            nn.init.uniform_(module.projection.bias, a=-k, b=k)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)

            if module.bias is not None:
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-k, b=k)

    def _get_adapters(self):
        raise AttributeError("Not needed for WavLM")

    def init_adapter_layers(self):
        raise AttributeError("Not needed for WavLM")

    def load_adapter(self):
        raise AttributeError("Not needed for WavLM")

WavLMBaseModelOutput = Wav2Vec2BaseModelOutput


class WavLMModel(Wav2Vec2Model):
    pass

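# Illustrative, commented-out usage sketch (not executed here): extracting WavLM hidden states
# with the standard Transformers API. The checkpoint name and the `raw_audio` variable
# (a list/array of float waveform samples at 16 kHz) are assumptions for the example only.
#
#     from transformers import AutoFeatureExtractor, WavLMModel
#
#     feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/wavlm-base")
#     model = WavLMModel.from_pretrained("microsoft/wavlm-base")
#     inputs = feature_extractor(raw_audio, sampling_rate=16000, return_tensors="pt")
#     outputs = model(**inputs)
#     hidden_states = outputs.last_hidden_state  # (batch, frames, hidden_size)
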
class WavLMForCTC(Wav2Vec2ForCTC):
    pass


class WavLMForSequenceClassification(Wav2Vec2ForSequenceClassification):
    pass


class WavLMForAudioFrameClassification(Wav2Vec2ForAudioFrameClassification):
    pass


class WavLMForXVector(Wav2Vec2ForXVector):
    pass

__all__ = [
    "WavLMForAudioFrameClassification",
    "WavLMForCTC",
    "WavLMForSequenceClassification",
    "WavLMForXVector",
    "WavLMModel",
    "WavLMPreTrainedModel",
]