torch_frame.nn.conv.FTTransformerConvs
- class FTTransformerConvs(channels: int, feedforward_channels: Optional[int] = None, num_layers: int = 3, nhead: int = 8, dropout: float = 0.2, activation: str = 'relu')
Bases: TableConv
The FT-Transformer backbone in the “Revisiting Deep Learning Models for Tabular Data” paper.
This module concatenates a learnable CLS token embedding x_cls to the input tensor x and applies a multi-layer Transformer to the concatenated tensor. After the Transformer layers, the output tensor is split into two parts: (1) x, corresponding to the original input tensor, and (2) x_cls, corresponding to the CLS token tensor.
- Parameters:
channels (int) – Input/output channel dimensionality
feedforward_channels (int, optional) – Hidden channels used by the feedforward network of the Transformer model. If None, it will be set to channels. (default: None)
num_layers (int) – Number of Transformer encoder layers. (default: 3)
nhead (int) – Number of heads in multi-head attention (default: 8)
dropout (float) – The dropout value (default: 0.2)
activation (str) – The activation function (default: relu)
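The CLS-token mechanism described above can be sketched in plain PyTorch. This is a minimal illustration, not the PyTorch Frame implementation: the class name ClsTransformerSketch is hypothetical, and the position of the CLS token (prepended at index 0) and its initialization are assumptions made for the example. Only the parameter names, defaults, and output shapes are taken from the documentation above.

```python
import torch
from torch import nn


class ClsTransformerSketch(nn.Module):
    """Hypothetical stand-in that mimics the documented behaviour."""

    def __init__(self, channels, feedforward_channels=None, num_layers=3,
                 nhead=8, dropout=0.2, activation="relu"):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=channels,
            nhead=nhead,
            # Fall back to `channels` when no feedforward width is given.
            dim_feedforward=feedforward_channels or channels,
            dropout=dropout,
            activation=activation,
            batch_first=True,  # inputs are [batch_size, num_cols, channels]
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        # Learnable CLS token embedding (initialization is an assumption).
        self.cls_embedding = nn.Parameter(torch.empty(1, 1, channels))
        nn.init.normal_(self.cls_embedding, std=0.01)

    def forward(self, x):
        batch_size = x.size(0)
        # Broadcast the single CLS embedding across the batch.
        x_cls = self.cls_embedding.expand(batch_size, -1, -1)
        # Assumption: the CLS token is placed in front of the columns.
        x_concat = torch.cat([x_cls, x], dim=1)
        out = self.transformer(x_concat)
        # Split back into per-column outputs and the CLS output.
        return out[:, 1:, :], out[:, 0, :]


x = torch.randn(32, 10, 16)  # [batch_size, num_cols, channels]
conv = ClsTransformerSketch(channels=16, num_layers=2, nhead=4)
x_out, x_cls = conv(x)
print(x_out.shape)  # torch.Size([32, 10, 16])
print(x_cls.shape)  # torch.Size([32, 16])
```

Note that channels must be divisible by nhead, which is why the example uses channels=16 with nhead=4 rather than relying on an arbitrary width.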
- forward(x: Tensor) → tuple[torch.Tensor, torch.Tensor]
CLS-token augmented Transformer convolution.
- Parameters:
x (Tensor) – Input tensor of shape [batch_size, num_cols, channels]
- Returns:
A tuple of two tensors: (1) the output tensor of shape [batch_size, num_cols, channels], corresponding to the input columns, and (2) the output tensor of shape [batch_size, channels], corresponding to the added CLS token column.
- Return type: