torch_frame.nn.conv.FTTransformerConvs

class FTTransformerConvs(channels: int, feedforward_channels: Optional[int] = None, num_layers: int = 3, nhead: int = 8, dropout: float = 0.2, activation: str = 'relu')

Bases: TableConv

The FT-Transformer backbone from the “Revisiting Deep Learning Models for Tabular Data” paper.

This module concatenates a learnable CLS token embedding x_cls to the input tensor x along the column dimension and applies a multi-layer Transformer encoder to the concatenated tensor. After the final Transformer layer, the output tensor is split into two parts: (1) x, corresponding to the original input columns, and (2) x_cls, corresponding to the CLS token.
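The concatenation, Transformer, and split steps described above can be sketched in plain PyTorch. This is a minimal illustration, not torch_frame's implementation: the helper `cls_concat_split`, the parameter name `cls_embedding`, the placement of the CLS token at position 0, and the use of `torch.nn.TransformerEncoder` with `batch_first=True` are all assumptions made for the example.

```python
import torch
from torch import Tensor, nn


def cls_concat_split(
    x: Tensor, cls_embedding: Tensor, encoder: nn.Module
) -> tuple[Tensor, Tensor]:
    """Attach a CLS column to x, encode, and split the result (sketch)."""
    batch_size = x.size(0)
    # Broadcast the single [channels] CLS vector to [batch_size, 1, channels].
    x_cls = cls_embedding.expand(batch_size, 1, -1)
    # Concatenate along the column axis: [batch_size, num_cols + 1, channels].
    x_concat = torch.cat([x_cls, x], dim=1)
    # Self-attention lets every column, including CLS, attend to all others.
    x_concat = encoder(x_concat)
    # Assumed layout: position 0 is the CLS token, the rest are input columns.
    return x_concat[:, 1:, :], x_concat[:, 0, :]


# Illustrative sizes only.
batch_size, num_cols, channels = 4, 5, 16
cls_embedding = nn.Parameter(0.01 * torch.randn(channels))
layer = nn.TransformerEncoderLayer(
    d_model=channels, nhead=8, dim_feedforward=channels, batch_first=True
)
encoder = nn.TransformerEncoder(layer, num_layers=3).eval()

x = torch.randn(batch_size, num_cols, channels)
x_out, x_cls_out = cls_concat_split(x, cls_embedding, encoder)
print(x_out.shape, x_cls_out.shape)  # torch.Size([4, 5, 16]) torch.Size([4, 16])
```

Passing `nn.Identity()` as the encoder is a quick way to check the bookkeeping: the column part then comes back unchanged and the CLS part equals the broadcast embedding.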

Parameters:
  • channels (int) – Input/output channel dimensionality.

  • feedforward_channels (int, optional) – Hidden channels used by the feedforward network of each Transformer encoder layer. If None, it is set to channels. (default: None)

  • num_layers (int) – Number of Transformer encoder layers. (default: 3)

  • nhead (int) – Number of heads in multi-head attention. (default: 8)

  • dropout (float) – The dropout probability. (default: 0.2)

  • activation (str) – The activation function. (default: 'relu')
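One plausible way these constructor arguments map onto PyTorch's built-in encoder is sketched below. The helper `build_encoder` is hypothetical, and the assumption that the backbone is built from `torch.nn.TransformerEncoderLayer` is ours, not something this page states. Under that assumption, channels must be divisible by nhead, because `torch.nn.MultiheadAttention` splits the embedding evenly across heads; with the default nhead=8, that means a multiple of 8.

```python
from typing import Optional

from torch import nn


def build_encoder(
    channels: int,
    feedforward_channels: Optional[int] = None,
    num_layers: int = 3,
    nhead: int = 8,
    dropout: float = 0.2,
    activation: str = "relu",
) -> nn.TransformerEncoder:
    """Hypothetical mapping of the documented arguments onto nn.TransformerEncoder."""
    # Documented default: fall back to `channels` when no hidden width is given.
    if feedforward_channels is None:
        feedforward_channels = channels
    # Assumed constraint inherited from nn.MultiheadAttention.
    if channels % nhead != 0:
        raise ValueError(f"channels={channels} must be divisible by nhead={nhead}")
    layer = nn.TransformerEncoderLayer(
        d_model=channels,
        nhead=nhead,
        dim_feedforward=feedforward_channels,
        dropout=dropout,
        activation=activation,
        batch_first=True,  # inputs laid out as [batch_size, num_cols + 1, channels]
    )
    # nn.TransformerEncoder stacks `num_layers` independent copies of `layer`.
    return nn.TransformerEncoder(layer, num_layers=num_layers)


encoder = build_encoder(channels=32)
first = encoder.layers[0]
print(len(encoder.layers), first.linear1.out_features, first.self_attn.num_heads)
# 3 32 8
```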

reset_parameters()

Resets all learnable parameters of the module.
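The page does not say how parameters are re-initialized. The sketch below shows one common scheme, purely as an assumption: a small normal initialization for the CLS vector, Xavier-uniform for weight matrices, zeros for biases, and ones for LayerNorm scales. The function name `reset_parameters_sketch` and the std of 0.01 are hypothetical choices for illustration.

```python
import torch
from torch import nn


def reset_parameters_sketch(cls_embedding: nn.Parameter, encoder: nn.Module) -> None:
    """Hypothetical re-initialization of a CLS vector plus Transformer stack."""
    # Small random CLS vector (std 0.01 is an assumed choice).
    nn.init.normal_(cls_embedding, std=0.01)
    for name, p in encoder.named_parameters():
        if p.dim() > 1:
            # Weight matrices: attention projections and feedforward layers.
            nn.init.xavier_uniform_(p)
        elif name.endswith("bias"):
            nn.init.zeros_(p)
        else:
            # Remaining 1-D weights are LayerNorm scales.
            nn.init.ones_(p)


channels = 16
cls_embedding = nn.Parameter(torch.empty(channels))
layer = nn.TransformerEncoderLayer(
    d_model=channels, nhead=8, dim_feedforward=channels, batch_first=True
)
encoder = nn.TransformerEncoder(layer, num_layers=2)

# Corrupt every parameter, then restore a known starting state.
with torch.no_grad():
    for p in encoder.parameters():
        p.fill_(5.0)
reset_parameters_sketch(cls_embedding, encoder)
```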

forward(x: Tensor) → tuple[torch.Tensor, torch.Tensor]

CLS-token augmented Transformer convolution.

Parameters:

x (Tensor) – Input tensor of shape [batch_size, num_cols, channels]

Returns:

A tuple (x, x_cls), where x is the output tensor of shape [batch_size, num_cols, channels] corresponding to the input columns, and x_cls is the output tensor of shape [batch_size, channels] corresponding to the added CLS token column.

Return type:

(torch.Tensor, torch.Tensor)
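The two outputs are returned in the order (x, x_cls), so unpacking must follow that order. The sketch below uses random stand-in tensors with the documented shapes in place of a real module call, and feeds the CLS output to a prediction head; the head architecture and all sizes are illustrative assumptions.

```python
import torch
from torch import Tensor, nn

batch_size, num_cols, channels, out_channels = 4, 5, 16, 3

# Stand-ins with the documented shapes; in practice this tuple would come from
# calling the module on a [batch_size, num_cols, channels] input.
outputs: tuple[Tensor, Tensor] = (
    torch.randn(batch_size, num_cols, channels),  # per-column outputs
    torch.randn(batch_size, channels),  # CLS token output
)

# Unpack in the documented order: columns first, CLS second.
x, x_cls = outputs

# Hypothetical prediction head that reads only the CLS representation.
head = nn.Sequential(
    nn.LayerNorm(channels),
    nn.ReLU(),
    nn.Linear(channels, out_channels),
)
logits = head(x_cls)
print(logits.shape)  # torch.Size([4, 3])
```

Because x_cls is already [batch_size, channels], it can go straight into a Linear layer without pooling over columns.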