# coding=utf-8 # Copyright 2023 IBM and HuggingFace Inc. team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch PatchTSMixer model.""" import math from dataclasses import dataclass from typing import Callable, Optional, Union import torch import torch.nn as nn from transformers.modeling_utils import PreTrainedModel from transformers.utils import ModelOutput from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...utils import auto_docstring, logging from ...utils.deprecation import deprecate_kwarg from .configuration_patchtsmixer import PatchTSMixerConfig logger = logging.get_logger(__name__) class PatchTSMixerGatedAttention(nn.Module): """ Module that applies gated attention to input data. Args: in_size (`int`): The input size. out_size (`int`): The output size. """ def __init__(self, in_size: int, out_size: int): super().__init__() self.attn_layer = nn.Linear(in_size, out_size) self.attn_softmax = nn.Softmax(dim=-1) def forward(self, inputs): attn_weight = self.attn_softmax(self.attn_layer(inputs)) inputs = inputs * attn_weight return inputs # Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTBatchNorm with PatchTST->PatchTSMixer class PatchTSMixerBatchNorm(nn.Module): """ Compute batch normalization over the sequence length (time) dimension. """ def __init__(self, config: PatchTSMixerConfig): super().__init__() self.batchnorm = nn.BatchNorm1d(config.d_model, eps=config.norm_eps) def forward(self, inputs: torch.Tensor): """ Parameters: inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`): input for Batch norm calculation Returns: `torch.Tensor` of shape `(batch_size, sequence_length, d_model)` """ output = inputs.transpose(1, 2) # output: (batch_size, d_model, sequence_length) output = self.batchnorm(output) return output.transpose(1, 2) class PatchTSMixerPositionalEncoding(nn.Module): """ Class for positional encoding """ def __init__(self, config: PatchTSMixerConfig): super().__init__() # positional encoding: [num_patches x d_model] if config.use_positional_encoding: self.position_enc = self._init_pe(config) else: self.position_enc = nn.Parameter(torch.zeros(config.num_patches, config.d_model)) @staticmethod def _init_pe(config: PatchTSMixerConfig) -> nn.Parameter: # Positional encoding if config.positional_encoding_type == "random": position_enc = nn.Parameter(torch.randn(config.num_patches, config.d_model), requires_grad=True) elif config.positional_encoding_type == "sincos": position_enc = torch.zeros(config.num_patches, config.d_model) position = torch.arange(0, config.num_patches).unsqueeze(1) div_term = torch.exp(torch.arange(0, config.d_model, 2) * -(math.log(10000.0) / config.d_model)) position_enc[:, 0::2] = torch.sin(position * div_term) position_enc[:, 1::2] = torch.cos(position * div_term) position_enc = position_enc - position_enc.mean() position_enc = position_enc / (position_enc.std() * 10) position_enc = nn.Parameter(position_enc, requires_grad=False) else: raise ValueError( f"{config.positional_encoding_type} is not a valid positional encoder. Available types are 'random' and 'sincos'." ) return position_enc def forward(self, patch_input: torch.Tensor): # hidden_state: [bs x num_channels x num_patches x d_model] hidden_state = patch_input + self.position_enc return hidden_state class PatchTSMixerNormLayer(nn.Module): """Normalization block Args: config (`PatchTSMixerConfig`): Configuration. """ def __init__(self, config: PatchTSMixerConfig): super().__init__() self.norm_mlp = config.norm_mlp if "batch" in config.norm_mlp.lower(): self.norm = PatchTSMixerBatchNorm(config) else: self.norm = nn.LayerNorm(config.d_model, eps=config.norm_eps) def forward(self, inputs: torch.Tensor): """ Args: inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`): Input to the normalization layer. Returns: `torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))` """ if "batch" in self.norm_mlp.lower(): # reshape the data inputs_reshaped = torch.reshape( inputs, ( inputs.shape[0] * inputs.shape[1], inputs.shape[2], inputs.shape[3], ), ) # inputs_reshaped: [batch_size*num_channels, num_patches, d_model] # inputs_reshaped: [batch_size*num_channels, num_patches, d_model] inputs_reshaped = self.norm(inputs_reshaped) # put back data to the original shape inputs = torch.reshape(inputs_reshaped, inputs.shape) else: inputs = self.norm(inputs) return inputs class PatchTSMixerMLP(nn.Module): def __init__(self, in_features, out_features, config): super().__init__() num_hidden = in_features * config.expansion_factor self.fc1 = nn.Linear(in_features, num_hidden) self.dropout1 = nn.Dropout(config.dropout) self.fc2 = nn.Linear(num_hidden, out_features) self.dropout2 = nn.Dropout(config.dropout) def forward(self, inputs: torch.Tensor): """ Args: inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`): Input to the MLP layer. Returns: `torch.Tensor` of the same shape as `inputs` """ inputs = self.dropout1(nn.functional.gelu(self.fc1(inputs))) inputs = self.fc2(inputs) inputs = self.dropout2(inputs) return inputs class PatchTSMixerChannelFeatureMixerBlock(nn.Module): """This module mixes the features in the channel dimension. Args: config (`PatchTSMixerConfig`): Configuration. """ def __init__(self, config: PatchTSMixerConfig): super().__init__() self.norm = PatchTSMixerNormLayer(config) self.gated_attn = config.gated_attn self.mlp = PatchTSMixerMLP( in_features=config.num_input_channels, out_features=config.num_input_channels, config=config, ) if config.gated_attn: self.gating_block = PatchTSMixerGatedAttention( in_size=config.num_input_channels, out_size=config.num_input_channels ) def forward(self, inputs: torch.Tensor): """ Args: inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`): input to the MLP layer Returns: `torch.Tensor` of the same shape as `inputs` """ residual = inputs inputs = self.norm(inputs) inputs = inputs.permute(0, 3, 2, 1) if self.gated_attn: inputs = self.gating_block(inputs) inputs = self.mlp(inputs) inputs = inputs.permute(0, 3, 2, 1) out = inputs + residual return out # Copied from transformers.models.bart.modeling_bart.eager_attention_forward def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: Optional[float] = None, dropout: float = 0.0, head_mask: Optional[torch.Tensor] = None, **kwargs, ): if scaling is None: scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) if head_mask is not None: attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->PatchTSMixer class PatchTSMixerAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__( self, embed_dim: int, num_heads: int, dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, is_causal: bool = False, config: Optional[PatchTSMixerConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" f" and `num_heads`: {num_heads})." ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) current_states = key_value_states if is_cross_attention else hidden_states key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, **kwargs, ) attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) return attn_output, attn_weights, None class PatchMixerBlock(nn.Module): """This module mixes the patch dimension. Args: config (`PatchTSMixerConfig`): Configuration. """ def __init__(self, config: PatchTSMixerConfig): super().__init__() self.norm = PatchTSMixerNormLayer(config) self.self_attn = config.self_attn self.gated_attn = config.gated_attn self.mlp = PatchTSMixerMLP( in_features=config.num_patches, out_features=config.num_patches, config=config, ) if config.gated_attn: self.gating_block = PatchTSMixerGatedAttention(in_size=config.num_patches, out_size=config.num_patches) if config.self_attn: self.self_attn_layer = PatchTSMixerAttention( embed_dim=config.d_model, num_heads=config.self_attn_heads, dropout=config.dropout, config=config, ) self.norm_attn = PatchTSMixerNormLayer(config) def forward(self, hidden_state): """ Args: hidden_state (`torch.Tensor`): Input tensor. Returns: `torch.Tensor`: Transformed tensor. """ residual = hidden_state hidden_state = self.norm(hidden_state) if self.self_attn: batch_size, n_vars, num_patches, d_model = hidden_state.shape hidden_state_reshaped = hidden_state.reshape(batch_size * n_vars, num_patches, d_model) x_attn, _, _ = self.self_attn_layer(hidden_state_reshaped, output_attentions=False) x_attn = x_attn.reshape(batch_size, n_vars, num_patches, d_model) # Transpose so that num_patches is the last dimension hidden_state = hidden_state.transpose(2, 3) hidden_state = self.mlp(hidden_state) if self.gated_attn: hidden_state = self.gating_block(hidden_state) # Transpose back hidden_state = hidden_state.transpose(2, 3) if self.self_attn: hidden_state = self.norm_attn(hidden_state + x_attn) out = hidden_state + residual return out class FeatureMixerBlock(nn.Module): """This module mixes the hidden feature dimension. Args: config (`PatchTSMixerConfig`): Configuration. """ def __init__(self, config: PatchTSMixerConfig): super().__init__() self.norm = PatchTSMixerNormLayer(config) self.gated_attn = config.gated_attn self.mlp = PatchTSMixerMLP( in_features=config.d_model, out_features=config.d_model, config=config, ) if config.gated_attn: self.gating_block = PatchTSMixerGatedAttention(in_size=config.d_model, out_size=config.d_model) def forward(self, hidden: torch.Tensor): """ Args: hidden (`torch.Tensor` of shape `(batch_size, num_patches, d_model)`): Input tensor to the layer. Returns: `torch.Tensor`: Transformed tensor. """ residual = hidden hidden = self.norm(hidden) hidden = self.mlp(hidden) if self.gated_attn: hidden = self.gating_block(hidden) out = hidden + residual return out class PatchTSMixerLayer(nn.Module): """ The `PatchTSMixer` layer that does all three kinds of mixing. Args: config (`PatchTSMixerConfig`): Configuration. """ def __init__(self, config: PatchTSMixerConfig): super().__init__() self.patch_mixer = PatchMixerBlock(config=config) self.feature_mixer = FeatureMixerBlock(config=config) self.mode = config.mode if config.mode == "mix_channel": self.channel_feature_mixer = PatchTSMixerChannelFeatureMixerBlock(config=config) def forward(self, hidden: torch.Tensor): """ Args: hidden (`torch.Tensor` of shape `(batch_size, num_patches, d_model)`): Input tensor to the layer. Returns: `torch.Tensor`: Transformed tensor. """ if self.mode == "mix_channel": hidden = self.channel_feature_mixer(hidden) hidden = self.patch_mixer(hidden) hidden = self.feature_mixer(hidden) # hidden: (batch_size x num_patches x d_model) return hidden class PatchTSMixerBlock(nn.Module): """The main computing framework of the `PatchTSMixer` model. Args: config (`PatchTSMixerConfig`): Configuration. """ def __init__(self, config: PatchTSMixerConfig): super().__init__() num_layers = config.num_layers self.mixers = nn.ModuleList([PatchTSMixerLayer(config=config) for _ in range(num_layers)]) def forward(self, hidden_state, output_hidden_states: bool = False): """ Args: hidden_state (`torch.Tensor`): The input tensor. output_hidden_states (`bool`, *optional*, defaults to False.): Whether to output the hidden states as well. Returns: `torch.Tensor`: The embedding. `list`: List of all hidden states if `output_hidden_states` is set to `True`. """ all_hidden_states = [] embedding = hidden_state for mod in self.mixers: embedding = mod(embedding) if output_hidden_states: all_hidden_states.append(embedding) if output_hidden_states: return embedding, all_hidden_states else: return embedding, None class PatchTSMixerForPredictionHead(nn.Module): """Prediction Head for Forecasting Args: config (`PatchTSMixerConfig`): Configuration. """ def __init__(self, config: PatchTSMixerConfig, distribution_output=None): super().__init__() self.prediction_channel_indices = config.prediction_channel_indices if self.prediction_channel_indices is not None: self.prediction_channel_indices.sort() self.dropout_layer = nn.Dropout(config.head_dropout) if distribution_output is None: self.base_forecast_block = nn.Linear((config.num_patches * config.d_model), config.prediction_length) else: self.base_forecast_block = distribution_output.get_parameter_projection( config.num_patches * config.d_model ) self.flatten = nn.Flatten(start_dim=-2) def forward(self, hidden_features): """ Args: hidden_features (`torch.Tensor` of shape `(batch_size, num_patch, d_model)` in `flatten` mode or `(batch_size, n_vars, num_patch, d_model)` in `common_channel`/`mix_channel` mode.): Input hidden features. Returns: `torch.Tensor` of shape `(batch_size, prediction_length, nvars)`. """ hidden_features = self.flatten(hidden_features) # [batch_size x n_vars x num_patch * d_model] hidden_features = self.dropout_layer(hidden_features) # [batch_size x n_vars x num_patch * d_model] forecast = self.base_forecast_block(hidden_features) # [batch_size x n_vars x prediction_length] if isinstance(forecast, tuple): forecast = tuple(z.transpose(-1, -2) for z in forecast) else: forecast = forecast.transpose(-1, -2) # [batch_size x prediction_length x n_vars] if self.prediction_channel_indices is not None: if isinstance(forecast, tuple): forecast = tuple(z[..., self.prediction_channel_indices] for z in forecast) else: forecast = forecast[..., self.prediction_channel_indices] # [batch_size x prediction_length x n_vars] return forecast class PatchTSMixerLinearHead(nn.Module): """Linear head for Classification and Regression. Args: config (`PatchTSMixerConfig`): Configuration. """ def __init__(self, config: PatchTSMixerConfig, distribution_output=None): super().__init__() self.head_aggregation = config.head_aggregation self.output_range = config.output_range if config.head_aggregation is None: mul_factor = config.num_patches else: mul_factor = 1 self.distribution_output = distribution_output if distribution_output is None: self.projection = nn.Linear( config.d_model * config.num_input_channels * mul_factor, config.num_targets, ) else: self.projection = distribution_output.get_parameter_projection( config.d_model * config.num_input_channels * mul_factor ) if config.head_aggregation is None: self.flatten = nn.Flatten(start_dim=-3) else: self.flatten = nn.Flatten(start_dim=-2) self.dropout = nn.Dropout(config.head_dropout) def forward(self, hidden_features): """ Args: hidden_features (`torch.Tensor` of shape `(batch_size x num_patch x d_model)` in `flatten` mode or `(batch_size x n_vars x num_patch x d_model)` in `common_channel`/`mix_channel` mode.): Input hidden features. Returns: `torch.Tensor` of shape `(batch_size x num_targets)`. """ # batch_size x d_model x num_patch or batch_size x n_vars x d_model x num_patch hidden_features = hidden_features.transpose(-1, -2) if self.head_aggregation == "use_last": # batch_size x d_model (flatten) or # batch_size x n_vars x d_model (common_channel) hidden_features = hidden_features[..., -1] elif self.head_aggregation == "max_pool": # batch_size x n_vars x d_model or batch_size x d_model hidden_features = hidden_features.max(dim=-1).values elif self.head_aggregation == "avg_pool": # batch_size x n_vars x d_model or batch_size x d_model hidden_features = hidden_features.mean(dim=-1) if self.flatten: hidden_features = self.flatten(hidden_features) hidden_features = self.dropout(hidden_features) hidden_features = self.projection(hidden_features) # batch_size x num_targets if (self.distribution_output is None) and (self.output_range is not None): hidden_features = ( torch.sigmoid(hidden_features) * (self.output_range[1] - self.output_range[0]) + self.output_range[0] ) return hidden_features @auto_docstring class PatchTSMixerPreTrainedModel(PreTrainedModel): # Weight initialization config: PatchTSMixerConfig base_model_prefix = "model" main_input_name = "past_values" supports_gradient_checkpointing = False def _init_weights(self, module): """Initialize weights""" if isinstance(module, PatchTSMixerPositionalEncoding): # initialize positional encoding if self.config.positional_encoding_type == "random": nn.init.normal_(module.position_enc, mean=0.0, std=0.1) elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)): module.bias.data.zero_() module.weight.data.fill_(1.0) elif isinstance(module, PatchTSMixerBatchNorm): module.batchnorm.bias.data.zero_() module.batchnorm.weight.data.fill_(1.0) elif isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=self.config.init_std) if module.bias is not None: module.bias.data.zero_() class PatchTSMixerPretrainHead(nn.Module): """Pretraining head. Args: config (`PatchTSMixerConfig`): Configuration. """ def __init__(self, config: PatchTSMixerConfig): super().__init__() self.dropout_layer = nn.Dropout(config.head_dropout) self.base_pt_block = nn.Linear(config.d_model, config.patch_length) def forward(self, hidden_features): """ Args: hidden_features (`torch.Tensor` of shape `(batch_size x num_patch x d_model)` in `flatten` mode or `(batch_size x n_vars x num_patch x d_model)` in `common_channel`/`mix_channel` mode.): Input hidden features. Returns: `torch.Tensor` of shape `(batch_size x n_vars x num_patch x patch_length)`. """ hidden_features = self.dropout_layer(hidden_features) forecast = self.base_pt_block(hidden_features) # [batch_size x n_vars x num_patch x patch_length] return forecast # Copied from transformers.models.patchtst.modeling_patchtst.random_masking def random_masking( inputs: torch.Tensor, mask_ratio: float, unmasked_channel_indices: Optional[list] = None, channel_consistent_masking: bool = False, mask_value: int = 0, ): """random_masking: Mask the input considering the control variables. Args: inputs (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, num_features)`): The input tensor to mask. mask_ratio (`float`): Masking ratio applied to mask the input data during random pretraining. It is the number between 0 and 1. unmasked_channel_indices (list, *optional*): Indices of channels that will not be masked. channel_consistent_masking (bool, *optional*, defaults to `False`): When true, masking will be same across all channels of a timeseries. Otherwise, masking positions will vary across channels. mask_value (int, *optional*, defaults to 0): Define the value of masked patches for pretraining. Returns: `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as input Tensor and mask tensor of shape [bs x c x n] """ if mask_ratio < 0 or mask_ratio >= 1: raise ValueError(f"Mask ratio {mask_ratio} has to be between 0 and 1.") batch_size, num_channels, sequence_length, num_features = inputs.shape device = inputs.device len_keep = int(sequence_length * (1 - mask_ratio)) if channel_consistent_masking: noise = torch.rand(batch_size, 1, sequence_length, device=device) # noise in [0, 1], bs x 1 x L noise = noise.repeat(1, num_channels, 1) # bs x num_channels x time else: # noise in [0, 1], bs x num_channels x L noise = torch.rand(batch_size, num_channels, sequence_length, device=device) # mask: [bs x num_channels x num_patch] mask = torch.ones(batch_size, num_channels, sequence_length, device=device) mask[:, :, :len_keep] = 0 # sort noise for each sample ids_shuffle = torch.argsort(noise, dim=-1) # ascend: small is keep, large is remove ids_restore = torch.argsort(ids_shuffle, dim=-1) # ids_restore: [bs x num_channels x L] mask = torch.gather(mask, dim=-1, index=ids_restore) mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features) # mask: [bs x num_channels x num_patches x patch_length] if unmasked_channel_indices is not None: mask[:, unmasked_channel_indices, :, :] = 0 inputs_mask = inputs.masked_fill(mask.bool(), mask_value) return inputs_mask, mask[..., 0] # Copied from transformers.models.patchtst.modeling_patchtst.forecast_masking def forecast_masking( inputs: torch.Tensor, num_forecast_mask_patches: Union[list, int], unmasked_channel_indices: Optional[list] = None, mask_value: int = 0, ): """Forecast masking that masks the last K patches where K is from the num_forecast_mask_patches. If num_forecast_mask_patches is a list, samples in the batch will be randomly masked by numbers defined in the list. Parameters: inputs (`torch.Tensor`): Input of shape `(bs, num_channels, num_patch, patch_length)` num_forecast_mask_patches (`list`): Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5]. unmasked_channel_indices (`list`, *optional*): Indices of channels that are not masked. mask_value (`int`, *optional*, defaults to 0): Values in the masked patches will be filled by `mask_value`. Returns: `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as inputs Tensor and Mask tensor of shape `(bs, num_channels , num_patch)` or `(bs, tsg1, tsg2, num_channels, num_patch)` """ if isinstance(num_forecast_mask_patches, int): num_forecast_mask_patches = [num_forecast_mask_patches] forecast_mask_ratios = [1 for _ in num_forecast_mask_patches] batch_size, num_channels, sequence_length, num_features = inputs.shape mask = torch.zeros(batch_size, num_channels, sequence_length, device=inputs.device) t_list = [] total_length = 0 total_ratio = sum(forecast_mask_ratios) for patch_length, ratio in zip(num_forecast_mask_patches, forecast_mask_ratios): if patch_length <= 0 or patch_length >= sequence_length: raise ValueError( f"num_forecast_mask_patches {patch_length} should be greater than 0 and less than total patches." ) temp_len = int(batch_size * ratio / total_ratio) t_list.append([patch_length, ratio, temp_len]) total_length += temp_len t_list = sorted(t_list, key=lambda x: x[2]) if total_length < batch_size: t_list[0][2] = t_list[0][2] + (batch_size - total_length) elif total_length > batch_size: t_list[-1][2] = t_list[-1][2] + (total_length - batch_size) batch1 = 0 for patch_len, _, temp_len in t_list: batch2 = batch1 + temp_len mask[batch1:batch2, :, -patch_len:] = 1 batch1 = batch2 perm = torch.randperm(mask.shape[0]) mask = mask[perm] mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features) # mask: [bs x num_channels x num_patch x patch_len] if unmasked_channel_indices is not None: mask[:, unmasked_channel_indices, :, :] = 0 inputs_mask = inputs.masked_fill(mask.bool(), mask_value) return inputs_mask, mask[..., 0] # Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTPatchify with PatchTST->PatchTSMixer class PatchTSMixerPatchify(nn.Module): """ A class to patchify the time series sequence into different patches Returns: `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)` """ def __init__(self, config: PatchTSMixerConfig): super().__init__() self.sequence_length = config.context_length self.patch_length = config.patch_length self.patch_stride = config.patch_stride if self.sequence_length <= self.patch_length: raise ValueError( f"Sequence length ({self.sequence_length}) has to be greater than the patch length ({self.patch_length})" ) # get the number of patches self.num_patches = (max(self.sequence_length, self.patch_length) - self.patch_length) // self.patch_stride + 1 new_sequence_length = self.patch_length + self.patch_stride * (self.num_patches - 1) self.sequence_start = self.sequence_length - new_sequence_length def forward(self, past_values: torch.Tensor): """ Parameters: past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*): Input for patchification Returns: `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)` """ sequence_length = past_values.shape[-2] if sequence_length != self.sequence_length: raise ValueError( f"Input sequence length ({sequence_length}) doesn't match model configuration ({self.sequence_length})." ) # output: [bs x new_sequence_length x num_channels] output = past_values[:, self.sequence_start :, :] # output: [bs x num_patches x num_input_channels x patch_length] output = output.unfold(dimension=-2, size=self.patch_length, step=self.patch_stride) # output: [bs x num_input_channels x num_patches x patch_length] output = output.transpose(-2, -3).contiguous() return output # Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTMasking with PatchTST->PatchTSMixer class PatchTSMixerMasking(nn.Module): """ Class to perform random or forecast masking. Parameters: config (`PatchTSMixerConfig`): model config Returns: x_mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`) Masked patched input mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`) Bool tensor indicating True on masked points """ def __init__(self, config: PatchTSMixerConfig): super().__init__() self.random_mask_ratio = config.random_mask_ratio self.channel_consistent_masking = config.channel_consistent_masking self.mask_type = config.mask_type self.num_forecast_mask_patches = config.num_forecast_mask_patches self.unmasked_channel_indices = config.unmasked_channel_indices self.mask_value = config.mask_value if self.unmasked_channel_indices is not None: self.unmasked_channel_indices = sorted(self.unmasked_channel_indices) def forward(self, patch_input: torch.Tensor): """ Parameters: patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*): Patch input Return: masked_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`) Masked patched input mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`) Bool tensor indicating True on masked points """ if self.mask_type == "random": masked_input, mask = random_masking( inputs=patch_input, mask_ratio=self.random_mask_ratio, unmasked_channel_indices=self.unmasked_channel_indices, channel_consistent_masking=self.channel_consistent_masking, mask_value=self.mask_value, ) elif self.mask_type == "forecast": masked_input, mask = forecast_masking( inputs=patch_input, num_forecast_mask_patches=self.num_forecast_mask_patches, unmasked_channel_indices=self.unmasked_channel_indices, mask_value=self.mask_value, ) else: raise ValueError(f"Invalid mask type {self.mask_type}.") # mask: [bs x num_input_channels x num_patch] mask = mask.bool() return masked_input, mask # Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTStdScaler with PatchTST->PatchTSMixer class PatchTSMixerStdScaler(nn.Module): """ Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by subtracting from the mean and dividing by the standard deviation. """ def __init__(self, config: PatchTSMixerConfig): super().__init__() self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 self.keepdim = config.keepdim if hasattr(config, "keepdim") else True self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-5 def forward( self, data: torch.Tensor, observed_indicator: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Parameters: data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): input for Batch norm calculation observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): Calculating the scale on the observed indicator. Returns: tuple of `torch.Tensor` of shapes (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, `(batch_size, 1, num_input_channels)`) """ denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim) denominator = denominator.clamp_min(1.0) loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator scale = torch.sqrt(variance + self.minimum_scale) return (data - loc) / scale, loc, scale # Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTMeanScaler with PatchTST->PatchTSMixer class PatchTSMixerMeanScaler(nn.Module): """ Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data accordingly. """ def __init__(self, config: PatchTSMixerConfig): super().__init__() self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 self.keepdim = config.keepdim if hasattr(config, "keepdim") else True self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10 self.default_scale = config.default_scale if hasattr(config, "default_scale") else None def forward( self, data: torch.Tensor, observed_indicator: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Parameters: data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): input for Batch norm calculation observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): Calculating the scale on the observed indicator. Returns: tuple of `torch.Tensor` of shapes (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, `(batch_size, 1, num_input_channels)`) """ ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True) num_observed = observed_indicator.sum(self.dim, keepdim=True) scale = ts_sum / torch.clamp(num_observed, min=1) # If `default_scale` is provided, we use it, otherwise we use the scale # of the batch. if self.default_scale is None: batch_sum = ts_sum.sum(dim=0) batch_observations = torch.clamp(num_observed.sum(0), min=1) default_scale = torch.squeeze(batch_sum / batch_observations) else: default_scale = self.default_scale * torch.ones_like(scale) # apply default scale where there are no observations scale = torch.where(num_observed > 0, scale, default_scale) # ensure the scale is at least `self.minimum_scale` scale = torch.clamp(scale, min=self.minimum_scale) scaled_data = data / scale if not self.keepdim: scale = scale.squeeze(dim=self.dim) return scaled_data, torch.zeros_like(scale), scale # Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTNOPScaler with PatchTST->PatchTSMixer class PatchTSMixerNOPScaler(nn.Module): """ Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data. """ def __init__(self, config: PatchTSMixerConfig): super().__init__() self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 self.keepdim = config.keepdim if hasattr(config, "keepdim") else True def forward( self, data: torch.Tensor, observed_indicator: Optional[torch.Tensor] = None ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Parameters: data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): input for Batch norm calculation Returns: tuple of `torch.Tensor` of shapes (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, `(batch_size, 1, num_input_channels)`) """ scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim) loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim) return data, loc, scale @dataclass @auto_docstring( custom_intro=""" Base class for `PatchTSMixerEncoderOutput`, with potential hidden states. """ ) class PatchTSMixerEncoderOutput(ModelOutput): r""" last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`): Hidden-state at the output of the last layer of the model. hidden_states (`tuple(torch.FloatTensor)`, *optional*): Hidden-states of the model at the output of each layer. """ last_hidden_state: Optional[torch.FloatTensor] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None class PatchTSMixerEncoder(PatchTSMixerPreTrainedModel): """ Encoder for PatchTSMixer which inputs patched time-series and outputs patched embeddings. Args: config (`PatchTSMixerConfig`): Configuration. """ def __init__(self, config: PatchTSMixerConfig): super().__init__(config) self.use_return_dict = config.use_return_dict self.patcher = nn.Linear(config.patch_length, config.d_model) if config.use_positional_encoding: self.positional_encoder = PatchTSMixerPositionalEncoding(config=config) else: self.positional_encoder = None self.mlp_mixer_encoder = PatchTSMixerBlock(config=config) # Initialize weights and apply final processing if config.post_init: self.post_init() @auto_docstring def forward( self, past_values: torch.Tensor, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = None, ) -> Union[tuple, PatchTSMixerEncoderOutput]: r""" past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`): Context values of the time series. For a pretraining task, this denotes the input time series to predict the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly, for classification or regression tasks, it denotes the appropriate context values of the time series. For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is greater than 1. Returns: `torch.FloatTensor` of shape `(batch_size, n_vars, num_patches, d_model)` """ return_dict = return_dict if return_dict is not None else self.use_return_dict # flatten [bs x num_patch x d_model]. common_channel/mix_channel: [bs x n_vars x num_patch x d_model] patches = self.patcher(past_values) # add positional encoder if self.positional_encoder is not None: patches = self.positional_encoder(patches) last_hidden_state, hidden_states = self.mlp_mixer_encoder(patches, output_hidden_states=output_hidden_states) if not return_dict: return tuple( v for v in [ last_hidden_state, hidden_states, ] ) return PatchTSMixerEncoderOutput(last_hidden_state=last_hidden_state, hidden_states=hidden_states) @dataclass @auto_docstring( custom_intro=""" Base class for model's outputs, with potential hidden states. """ ) class PatchTSMixerModelOutput(ModelOutput): r""" last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`): Hidden-state at the output of the last layer of the model. hidden_states (`tuple(torch.FloatTensor)`, *optional*): Hidden-states of the model at the output of each layer. patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`): Patched input data to the model. mask (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches)`, *optional*): Bool Tensor indicating True in masked patches and False otherwise. loc (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*): Gives the mean of the context window per channel. Used for revin denorm outside the model, if revin enabled. scale (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*): Gives the std dev of the context window per channel. Used for revin denorm outside the model, if revin enabled. """ last_hidden_state: Optional[torch.FloatTensor] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None patch_input: Optional[torch.FloatTensor] = None mask: Optional[torch.FloatTensor] = None loc: Optional[torch.FloatTensor] = None scale: Optional[torch.FloatTensor] = None @auto_docstring( custom_intro=""" The PatchTSMixer Model for time-series forecasting. """ ) class PatchTSMixerModel(PatchTSMixerPreTrainedModel): def __init__(self, config: PatchTSMixerConfig, mask_input: bool = False): r""" mask_input (bool, *optional*, defaults to `False`): Whether to mask the input using the [`PatchTSMixerMasking`] module. """ super().__init__(config) self.use_return_dict = config.use_return_dict self.encoder = PatchTSMixerEncoder(config) self.patching = PatchTSMixerPatchify(config) if mask_input is True: self.masking = PatchTSMixerMasking(config) else: self.masking = None if config.scaling == "mean": self.scaler = PatchTSMixerMeanScaler(config) elif config.scaling == "std" or config.scaling is True: self.scaler = PatchTSMixerStdScaler(config) else: self.scaler = PatchTSMixerNOPScaler(config) # Initialize weights and apply final processing if config.post_init: self.post_init() @auto_docstring def forward( self, past_values: torch.Tensor, observed_mask: Optional[torch.Tensor] = None, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = None, ) -> PatchTSMixerModelOutput: r""" past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`): Context values of the time series. For a pretraining task, this denotes the input time series to predict the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly, for classification or regression tasks, it denotes the appropriate context values of the time series. For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is greater than 1. observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in `[0, 1]`: - 1 for values that are **observed**, - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). """ return_dict = return_dict if return_dict is not None else self.use_return_dict mask = None if observed_mask is None: observed_mask = torch.ones_like(past_values) scaled_past_values, loc, scale = self.scaler(past_values, observed_mask) patched_x = self.patching(scaled_past_values) # [batch_size x num_input_channels x num_patch x patch_length enc_input = patched_x if self.masking is not None: enc_input, mask = self.masking(patched_x) # enc_input: [batch_size x num_input_channels x num_patch x patch_length] # mask: [batch_size x num_input_channels x num_patch] encoder_output = self.encoder( enc_input, output_hidden_states=output_hidden_states, return_dict=return_dict, ) if isinstance(encoder_output, tuple): encoder_output = PatchTSMixerEncoderOutput(*encoder_output) if not return_dict: return tuple( v for v in [ encoder_output.last_hidden_state, encoder_output.hidden_states, patched_x, mask, loc, scale, ] ) return PatchTSMixerModelOutput( last_hidden_state=encoder_output.last_hidden_state, hidden_states=encoder_output.hidden_states, patch_input=patched_x, mask=mask, loc=loc, scale=scale, ) @dataclass @auto_docstring( custom_intro=""" Output type of [`PatchTSMixerForPreTrainingOutput`]. """ ) class PatchTSMixerForPreTrainingOutput(ModelOutput): r""" loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`): Total loss prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, patch_length)`): Prediction output from the pretrain head. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`): Backbone embeddings before passing through the head. hidden_states (`tuple(torch.FloatTensor)`, *optional*): Hidden-states of the model at the output of each layer. """ loss: Optional[torch.FloatTensor] = None prediction_outputs: Optional[torch.FloatTensor] = None last_hidden_state: Optional[torch.FloatTensor] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None @auto_docstring( custom_intro=""" `PatchTSMixer` for mask pretraining. """ ) class PatchTSMixerForPretraining(PatchTSMixerPreTrainedModel): def __init__(self, config: PatchTSMixerConfig): super().__init__(config) self.model = PatchTSMixerModel(config, mask_input=True) self.head = PatchTSMixerPretrainHead(config=config) self.masked_loss = config.masked_loss self.use_return_dict = config.use_return_dict # Initialize weights and apply final processing if config.post_init: self.post_init() @auto_docstring def forward( self, past_values: torch.Tensor, observed_mask: Optional[torch.Tensor] = None, output_hidden_states: Optional[bool] = False, return_loss: bool = True, return_dict: Optional[bool] = None, ) -> PatchTSMixerForPreTrainingOutput: r""" past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`): Context values of the time series. For a pretraining task, this denotes the input time series to predict the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly, for classification or regression tasks, it denotes the appropriate context values of the time series. For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is greater than 1. observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in `[0, 1]`: - 1 for values that are **observed**, - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). return_loss (`bool`, *optional*): Whether to return the loss in the `forward` call. """ return_dict = return_dict if return_dict is not None else self.use_return_dict if self.masked_loss is True: loss = torch.nn.MSELoss(reduction="none") else: loss = torch.nn.MSELoss(reduction="mean") # past_values: tensor [batch_size x context_length x num_input_channels] model_output = self.model( past_values, observed_mask=observed_mask, output_hidden_states=output_hidden_states, return_dict=return_dict, ) # x.last_hidden_state: [batch_size x nvars x num_patch x d_model] if isinstance(model_output, tuple): model_output = PatchTSMixerModelOutput(*model_output) x_hat = self.head(model_output.last_hidden_state) # tensor [batch_size x nvars x num_patch x patch_length] if return_loss is True: loss_val = loss(x_hat, model_output.patch_input) else: loss_val = None # calculate masked_loss if self.masked_loss is True and loss_val is not None: loss_val = (loss_val.mean(dim=-1) * model_output.mask).sum() / (model_output.mask.sum() + 1e-10) if not return_dict: return tuple( v for v in [ loss_val, x_hat, model_output.last_hidden_state, model_output.hidden_states, ] ) return PatchTSMixerForPreTrainingOutput( loss=loss_val, prediction_outputs=x_hat, # tensor [batch_size x nvars x num_patch x patch_length] last_hidden_state=model_output.last_hidden_state, # x: [batch_size x nvars x num_patch x d_model] hidden_states=model_output.hidden_states, ) @dataclass @auto_docstring( custom_intro=""" Output type of [`PatchTSMixerForPredictionOutput`]. """ ) class PatchTSMixerForPredictionOutput(ModelOutput): r""" loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`): Total loss. prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_input_channels)`): Prediction output from the forecast head. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`): Backbone embeddings before passing through the head. hidden_states (`tuple(torch.FloatTensor)`, *optional*): Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. loc (`torch.FloatTensor`, *optional* of shape `(batch_size, 1, num_input_channels)`): Input mean scale (`torch.FloatTensor`, *optional* of shape `(batch_size, 1, num_input_channels)`): Input std dev """ loss: Optional[torch.FloatTensor] = None prediction_outputs: Optional[torch.FloatTensor] = None last_hidden_state: Optional[torch.FloatTensor] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None loc: Optional[torch.FloatTensor] = None scale: Optional[torch.FloatTensor] = None @dataclass @auto_docstring( custom_intro=""" Base class for time series model's predictions outputs that contains the sampled values from the chosen distribution. """ ) class SamplePatchTSMixerPredictionOutput(ModelOutput): r""" sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, number_channels)`): Sampled values from the chosen distribution. """ sequences: Optional[torch.FloatTensor] = None @dataclass @auto_docstring( custom_intro=""" Base class for time series model's predictions outputs that contains the sampled values from the chosen distribution. """ ) class SamplePatchTSMixerRegressionOutput(ModelOutput): r""" sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, number_channels)`): Sampled values from the chosen distribution. """ sequences: Optional[torch.FloatTensor] = None # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.nll def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor: """ Computes the negative log likelihood loss from input distribution with respect to target. """ return -input.log_prob(target) # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.weighted_average def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor: """ Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero, meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`. Args: input_tensor (`torch.FloatTensor`): Input tensor, of which the average must be computed. weights (`torch.FloatTensor`, *optional*): Weights tensor, of the same shape as `input_tensor`. dim (`int`, *optional*): The dim along which to average `input_tensor`. Returns: `torch.FloatTensor`: The tensor with values averaged along the specified `dim`. """ if weights is not None: weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor)) sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0) return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights else: return input_tensor.mean(dim=dim) class PatchTSMixerForPrediction(PatchTSMixerPreTrainedModel): r""" `PatchTSMixer` for forecasting application. Args: config (`PatchTSMixerConfig`): Configuration. Returns: `None`. """ def __init__(self, config: PatchTSMixerConfig): super().__init__(config) self.loss = config.loss self.use_return_dict = config.use_return_dict self.prediction_channel_indices = config.prediction_channel_indices self.num_parallel_samples = config.num_parallel_samples if config.loss == "mse": self.distribution_output = None else: dim = config.prediction_length distribution_output_map = { "student_t": StudentTOutput, "normal": NormalOutput, "negative_binomial": NegativeBinomialOutput, } output_class = distribution_output_map.get(config.distribution_output, None) if output_class is not None: self.distribution_output = output_class(dim=dim) else: raise ValueError(f"Unknown distribution output {config.distribution_output}") self.model = PatchTSMixerModel(config) self.head = PatchTSMixerForPredictionHead( config=config, distribution_output=self.distribution_output, ) # Initialize weights and apply final processing if config.post_init: self.post_init() @auto_docstring def forward( self, past_values: torch.Tensor, observed_mask: Optional[torch.Tensor] = None, future_values: Optional[torch.Tensor] = None, output_hidden_states: Optional[bool] = False, return_loss: bool = True, return_dict: Optional[bool] = None, ) -> PatchTSMixerForPredictionOutput: r""" past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`): Context values of the time series. For a pretraining task, this denotes the input time series to predict the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly, for classification or regression tasks, it denotes the appropriate context values of the time series. For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is greater than 1. observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in `[0, 1]`: - 1 for values that are **observed**, - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). future_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,: `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*): Target values of the time series, that serve as labels for the model. The `future_values` is what the Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT required for a pretraining task. For a forecasting task, the shape is be `(batch_size, target_len, num_input_channels)`. Even if we want to forecast only specific channels by setting the indices in `prediction_channel_indices` parameter, pass the target data with all channels, as channel Filtering for both prediction and target will be manually applied before the loss computation. return_loss (`bool`, *optional*): Whether to return the loss in the `forward` call. """ if self.loss == "mse": loss = nn.MSELoss(reduction="mean") elif self.loss == "nll": loss = nll else: raise ValueError("Invalid loss function: Allowed values: mse and nll") return_dict = return_dict if return_dict is not None else self.use_return_dict # past_values: tensor [batch_size x context_length x num_input_channels] model_output = self.model( past_values, observed_mask=observed_mask, output_hidden_states=output_hidden_states, return_dict=return_dict, ) # model_output: [batch_size x nvars x num_patch x d_model] if isinstance(model_output, tuple): model_output = PatchTSMixerModelOutput(*model_output) # tensor [batch_size x prediction_length x num_input_channels] y_hat = self.head(model_output.last_hidden_state) loss_val = None if self.prediction_channel_indices is not None: if self.distribution_output: distribution = self.distribution_output.distribution( y_hat, loc=model_output.loc[..., self.prediction_channel_indices], scale=model_output.scale[..., self.prediction_channel_indices], ) if future_values is not None and return_loss is True: loss_val = loss( distribution, future_values[..., self.prediction_channel_indices], ) # take average of the loss loss_val = weighted_average(loss_val) else: y_hat = ( y_hat * model_output.scale[..., self.prediction_channel_indices] + model_output.loc[..., self.prediction_channel_indices] ) if future_values is not None and return_loss is True: loss_val = loss(y_hat, future_values[..., self.prediction_channel_indices]) else: if self.distribution_output: distribution = self.distribution_output.distribution( y_hat, loc=model_output.loc, scale=model_output.scale ) if future_values is not None and return_loss is True: loss_val = loss(distribution, future_values) loss_val = weighted_average(loss_val) else: y_hat = y_hat * model_output.scale + model_output.loc if future_values is not None and return_loss is True: loss_val = loss(y_hat, future_values) if self.prediction_channel_indices is not None: loc = model_output.loc[..., self.prediction_channel_indices] scale = model_output.scale[..., self.prediction_channel_indices] else: loc = model_output.loc scale = model_output.scale if not return_dict: return tuple( v for v in [ loss_val, y_hat, model_output.last_hidden_state, model_output.hidden_states, loc, scale, ] ) return PatchTSMixerForPredictionOutput( loss=loss_val, prediction_outputs=y_hat, # tensor [batch_size x prediction_length x num_input_channels] last_hidden_state=model_output.last_hidden_state, # x: [batch_size x nvars x num_patch x d_model] hidden_states=model_output.hidden_states, loc=loc, scale=scale, ) @torch.no_grad() def generate( self, past_values: torch.Tensor, observed_mask: Optional[torch.Tensor] = None, ) -> SamplePatchTSMixerPredictionOutput: """ Generate sequences of sample predictions from a model with a probability distribution head. Args: past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`): Past values of the time series that serves as context in order to predict the future. observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in `[0, 1]`: - 1 for values that are **observed**, - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). Return: [`SamplePatchTSMixerPredictionOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of samples, prediction_length, num_input_channels)`. """ # get number of samples num_parallel_samples = self.num_parallel_samples # get model output outputs = self( past_values=past_values, future_values=None, observed_mask=observed_mask, output_hidden_states=False, ) # get distribution distribution = self.distribution_output.distribution( outputs.prediction_outputs, loc=outputs.loc, scale=outputs.scale ) # get samples: list of [batch_size x prediction_length x num_channels] samples = [distribution.sample() for _ in range(num_parallel_samples)] # stack tensors samples = torch.stack(samples, dim=1) # [batch_size x num_samples x prediction_length x num_channels] return SamplePatchTSMixerPredictionOutput(sequences=samples) @dataclass @auto_docstring( custom_intro=""" Output type of [`PatchTSMixerForTimeSeriesClassificationOutput`]. """ ) class PatchTSMixerForTimeSeriesClassificationOutput(ModelOutput): r""" loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`): Total loss. prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_labels)`): Prediction output from the classification head. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`): Backbone embeddings before passing through the head. hidden_states (`tuple(torch.FloatTensor)`, *optional*): Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. """ loss: Optional[torch.FloatTensor] = None prediction_outputs: Optional[torch.FloatTensor] = None last_hidden_state: Optional[torch.FloatTensor] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None class PatchTSMixerForTimeSeriesClassification(PatchTSMixerPreTrainedModel): r""" `PatchTSMixer` for classification application. Args: config (`PatchTSMixerConfig`): Configuration. Returns: `None`. """ def __init__(self, config: PatchTSMixerConfig): super().__init__(config) self.model = PatchTSMixerModel(config) self.head = PatchTSMixerLinearHead( config=config, ) self.use_return_dict = config.use_return_dict if config.scaling in ["std", "mean", True]: self.inject_scale = InjectScalerStatistics4D(d_model=config.d_model, num_patches=config.num_patches) else: self.inject_scale = None # Initialize weights and apply final processing if config.post_init: self.post_init() @auto_docstring def forward( self, past_values: torch.Tensor, target_values: Optional[torch.Tensor] = None, output_hidden_states: Optional[bool] = False, return_loss: bool = True, return_dict: Optional[bool] = None, ) -> PatchTSMixerForTimeSeriesClassificationOutput: r""" past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`): Context values of the time series. For a pretraining task, this denotes the input time series to predict the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly, for classification or regression tasks, it denotes the appropriate context values of the time series. For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is greater than 1. target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting, `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*): Target values of the time series, that serve as labels for the model. The `target_values` is what the Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT required for a pretraining task. For a forecasting task, the shape is be `(batch_size, target_len, num_input_channels)`. Even if we want to forecast only specific channels by setting the indices in `prediction_channel_indices` parameter, pass the target data with all channels, as channel Filtering for both prediction and target will be manually applied before the loss computation. For a classification task, it has a shape of `(batch_size,)`. For a regression task, it has a shape of `(batch_size, num_targets)`. return_loss (`bool`, *optional*): Whether to return the loss in the `forward` call. """ loss = torch.nn.CrossEntropyLoss() return_dict = return_dict if return_dict is not None else self.use_return_dict model_output = self.model( past_values, output_hidden_states=output_hidden_states, return_dict=return_dict, ) # x: [batch_size x nvars x num_patch x d_model] if isinstance(model_output, tuple): model_output = PatchTSMixerModelOutput(*model_output) if self.inject_scale is not None: model_output.last_hidden_state = self.inject_scale( model_output.last_hidden_state, loc=model_output.loc, scale=model_output.scale, ) # x: [batch_size x nvars x num_patch x d_model] y_hat = self.head(model_output.last_hidden_state) # tensor [batch_size x n_labels] if target_values is not None and return_loss is True: loss_val = loss(y_hat, target_values) else: loss_val = None if not return_dict: return tuple( v for v in [ loss_val, y_hat, model_output.last_hidden_state, model_output.hidden_states, ] ) return PatchTSMixerForTimeSeriesClassificationOutput( loss=loss_val, prediction_outputs=y_hat, # tensor [batch_size x n_labels] last_hidden_state=model_output.last_hidden_state, # x: [batch_size x nvars x num_patch x d_model] hidden_states=model_output.hidden_states, ) @dataclass @auto_docstring( custom_intro=""" Output type of [`PatchTSMixerForRegressionOutput`]. """ ) class PatchTSMixerForRegressionOutput(ModelOutput): r""" loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`): Total loss. regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`): Prediction output from the regression head. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`): Backbone embeddings before passing through the head. hidden_states (`tuple(torch.FloatTensor)`, *optional*): Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. """ loss: Optional[torch.FloatTensor] = None regression_outputs: Optional[torch.FloatTensor] = None last_hidden_state: Optional[torch.FloatTensor] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None class InjectScalerStatistics4D(nn.Module): def __init__(self, d_model: int, num_patches: int, expansion: int = 2): super().__init__() self.inverse_trans_expansion = nn.Linear(d_model + 2, expansion * d_model) self.inverse_trans_compression = nn.Linear(expansion * d_model, d_model) self.map_scale_expansion = nn.Linear(2, 2 * expansion) self.map_scale_compression = nn.Linear(2 * expansion, 2) self.num_patches = num_patches def forward(self, inputs: torch.Tensor, loc: torch.Tensor, scale: torch.Tensor): """ Args: inputs (`torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)`) loc (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`) scale (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`) Returns: `torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)` """ mean = loc.transpose(-1, -2) # [batch_size x n_channels x 1 ] mean = mean.unsqueeze(-2) # [batch_size x n_channels x 1 x 1] mean = mean.repeat(1, 1, self.num_patches, 1) # [batch_size x n_channels x num_patch x 1] stdev = scale.transpose(-1, -2) # [batch_size x n_channels x 1 ] stdev = stdev.unsqueeze(-2) # [batch_size x n_channels x 1 x 1] stdev = stdev.repeat(1, 1, self.num_patches, 1) # [batch_size x n_channels x num_patch x 1] concat_stats = torch.cat([mean, stdev], dim=-1) # [batch_size x n_channels x num_patch x 2] concat_stats = self.map_scale_expansion(concat_stats) # [batch_size x n_channels x num_patch x (2*expansion)] concat_stats = self.map_scale_compression(concat_stats) # [batch_size x n_channels x num_patch x 2] inputs = torch.cat([inputs, concat_stats], dim=-1) # [batch_size x channels x num_patch x d_model+2] inputs = self.inverse_trans_expansion(inputs) # [batch_size x channels x num_patch x (expansion*d_model)] inputs = self.inverse_trans_compression(inputs) # [batch_size x channels x num_patch x d_model] return inputs @auto_docstring( custom_intro=""" `PatchTSMixer` for regression application. """ ) class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel): def __init__(self, config: PatchTSMixerConfig): super().__init__(config) self.model = PatchTSMixerModel(config) self.loss = config.loss self.distribution_output = config.distribution_output self.use_return_dict = config.use_return_dict self.num_parallel_samples = config.num_parallel_samples if config.loss == "mse": self.distribution_output = None else: distribution_output_map = { "student_t": StudentTOutput, "normal": NormalOutput, "negative_binomial": NegativeBinomialOutput, } output_class = distribution_output_map.get(config.distribution_output) if output_class is not None: self.distribution_output = output_class(dim=config.num_targets) else: raise ValueError(f"Unknown distribution output {config.distribution_output}") if config.scaling in ["std", "mean", True]: self.inject_scale = InjectScalerStatistics4D(d_model=config.d_model, num_patches=config.num_patches) else: self.inject_scale = None self.head = PatchTSMixerLinearHead( config=config, distribution_output=self.distribution_output, ) # Initialize weights and apply final processing if config.post_init: self.post_init() @auto_docstring def forward( self, past_values: torch.Tensor, target_values: Optional[torch.Tensor] = None, output_hidden_states: Optional[bool] = False, return_loss: bool = True, return_dict: Optional[bool] = None, ) -> PatchTSMixerForRegressionOutput: r""" past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`): Context values of the time series. For a pretraining task, this denotes the input time series to predict the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly, for classification or regression tasks, it denotes the appropriate context values of the time series. For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is greater than 1. target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting, `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*): Target values of the time series, that serve as labels for the model. The `target_values` is what the Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT required for a pretraining task. For a forecasting task, the shape is be `(batch_size, target_len, num_input_channels)`. Even if we want to forecast only specific channels by setting the indices in `prediction_channel_indices` parameter, pass the target data with all channels, as channel Filtering for both prediction and target will be manually applied before the loss computation. For a classification task, it has a shape of `(batch_size,)`. For a regression task, it has a shape of `(batch_size, num_targets)`. return_loss (`bool`, *optional*): Whether to return the loss in the `forward` call. """ if self.loss == "mse": loss = nn.MSELoss(reduction="mean") elif self.loss == "nll": loss = nll else: raise ValueError("Invalid loss function: Allowed values: mse and nll") return_dict = return_dict if return_dict is not None else self.use_return_dict model_output = self.model( past_values, output_hidden_states=output_hidden_states, return_dict=return_dict, ) # model_output: [batch_size x nvars x num_patch x d_model] if isinstance(model_output, tuple): model_output = PatchTSMixerModelOutput(*model_output) if self.inject_scale is not None: model_output.last_hidden_state = self.inject_scale( model_output.last_hidden_state, loc=model_output.loc, scale=model_output.scale, ) # x: [batch_size x nvars x num_patch x d_model] y_hat = self.head(model_output.last_hidden_state) # [batch_size x num_targets] if target_values is not None and return_loss is True: if self.distribution_output: if self.distribution_output == "negative_binomial" and torch.any(target_values < 0): raise Exception("target_values cannot be negative for negative_binomial distribution.") distribution = self.distribution_output.distribution(y_hat) # y_hat should be a 2-tuple, each with dimension [bs, num_targets] y_hat = tuple([item.view(-1, self.config.num_targets) for item in y_hat]) loss_val = loss(distribution, target_values) # take average of the loss loss_val = weighted_average(loss_val) else: loss_val = loss(y_hat, target_values) else: loss_val = None if not return_dict: return tuple( v for v in [ loss_val, y_hat, model_output.last_hidden_state, model_output.hidden_states, ] ) return PatchTSMixerForRegressionOutput( loss=loss_val, regression_outputs=y_hat, # tensor [batch_size x num_targets] last_hidden_state=model_output.last_hidden_state, # [batch_size x nvars x num_patch x d_model] hidden_states=model_output.hidden_states, ) @torch.no_grad() def generate( self, past_values: torch.Tensor, ) -> SamplePatchTSMixerRegressionOutput: """ Generate sequences of sample predictions from a model with a probability distribution head. Args: past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`): Past values of the time series that serves as context in order to predict the target values. Return: [`SamplePatchTSMixerRegressionOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of samples, num_targets)`. """ # get number of samples num_parallel_samples = self.num_parallel_samples # get model output outputs = self( past_values=past_values, target_values=None, output_hidden_states=False, ) # get distribution distribution = self.distribution_output.distribution(outputs.regression_outputs) # get samples samples = [ distribution.sample() for _ in range(num_parallel_samples) ] # samples: list of [batch_size x num_targets] # stack tensors # [batch_size x num_samples x num_targets] samples = torch.stack(samples, dim=1).view(-1, num_parallel_samples, self.config.num_targets) return SamplePatchTSMixerRegressionOutput(sequences=samples) __all__ = [ "PatchTSMixerPreTrainedModel", "PatchTSMixerModel", "PatchTSMixerForPretraining", "PatchTSMixerForPrediction", "PatchTSMixerForTimeSeriesClassification", "PatchTSMixerForRegression", ]