Source code for torch_frame.nn.conv.ft_transformer_convs

from __future__ import annotations

import torch
from torch import Tensor
from torch.nn import (
    LayerNorm,
    Parameter,
    TransformerEncoder,
    TransformerEncoderLayer,
)

from torch_frame.nn.conv import TableConv


[docs]class FTTransformerConvs(TableConv): r"""The FT-Transformer backbone in the `"Revisiting Deep Learning Models for Tabular Data" <https://arxiv.org/abs/2106.11959>`_ paper. This module concatenates a learnable CLS token embedding :obj:`x_cls` to the input tensor :obj:`x` and applies a multi-layer Transformer on the concatenated tensor. After the Transformer layer, the output tensor is divided into two parts: (1) :obj:`x`, corresponding to the original input tensor, and (2) :obj:`x_cls`, corresponding to the CLS token tensor. Args: channels (int): Input/output channel dimensionality feedforward_channels (int, optional): Hidden channels used by feedforward network of the Transformer model. If :obj:`None`, it will be set to :obj:`channels` (default: :obj:`None`) num_layers (int): Number of transformer encoder layers. (default: 3) nhead (int): Number of heads in multi-head attention (default: 8) dropout (int): The dropout value (default: 0.1) activation (str): The activation function (default: :obj:`relu`) """ def __init__( self, channels: int, feedforward_channels: int | None = None, # Arguments for Transformer num_layers: int = 3, nhead: int = 8, dropout: float = 0.2, activation: str = 'relu', ): super().__init__() encoder_layer = TransformerEncoderLayer( d_model=channels, nhead=nhead, dim_feedforward=feedforward_channels or channels, dropout=dropout, activation=activation, # Input and output tensors are provided as # [batch_size, seq_len, channels] batch_first=True, ) encoder_norm = LayerNorm(channels) self.transformer = TransformerEncoder(encoder_layer=encoder_layer, num_layers=num_layers, norm=encoder_norm) self.cls_embedding = Parameter(torch.empty(channels)) self.reset_parameters()
[docs] def reset_parameters(self): torch.nn.init.normal_(self.cls_embedding, std=0.01) for p in self.transformer.parameters(): if p.dim() > 1: torch.nn.init.xavier_uniform_(p)
[docs] def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: r"""CLS-token augmented Transformer convolution. Args: x (Tensor): Input tensor of shape [batch_size, num_cols, channels] Returns: (torch.Tensor, torch.Tensor): (Output tensor of shape [batch_size, num_cols, channels] corresponding to the input columns, Output tensor of shape [batch_size, channels], corresponding to the added CLS token column.) """ B, _, _ = x.shape # [batch_size, num_cols, channels] x_cls = self.cls_embedding.repeat(B, 1, 1) # [batch_size, num_cols + 1, channels] x_concat = torch.cat([x_cls, x], dim=1) # [batch_size, num_cols + 1, channels] x_concat = self.transformer(x_concat) x_cls, x = x_concat[:, 0, :], x_concat[:, 1:, :] return x, x_cls