Source code for torch_frame.nn.models.ft_transformer

from __future__ import annotations

from typing import Any

from torch import Tensor
from torch.nn import LayerNorm, Linear, Module, ReLU, Sequential

import torch_frame
from torch_frame import TensorFrame, stype
from torch_frame.data.stats import StatType
from torch_frame.nn.conv import FTTransformerConvs
from torch_frame.nn.encoder.stype_encoder import (
    EmbeddingEncoder,
    LinearEncoder,
    StypeEncoder,
)
from torch_frame.nn.encoder.stypewise_encoder import StypeWiseFeatureEncoder


class FTTransformer(Module):
    r"""The FT-Transformer model introduced in the
    `"Revisiting Deep Learning Models for Tabular Data"
    <https://arxiv.org/abs/2106.11959>`_ paper.

    The model is a three-stage pipeline: a stype-wise feature encoder, a
    stack of :class:`FTTransformerConvs` layers, and a
    ``LayerNorm -> ReLU -> Linear`` decoder head.

    .. note::

        For an example of using FTTransformer, see `examples/revisiting.py
        <https://github.com/pyg-team/pytorch-frame/blob/master/examples/
        revisiting.py>`_.

    Args:
        channels (int): Hidden channel dimensionality
        out_channels (int): Output channels dimensionality
        num_layers (int): Number of layers. Must be a positive integer;
            this argument is required and has no default value.
        col_stats(dict[str,dict[:class:`torch_frame.data.stats.StatType`,Any]]):
            A dictionary that maps column name into stats.
            Available as :obj:`dataset.col_stats`.
        col_names_dict (dict[:obj:`torch_frame.stype`, list[str]]): A
            dictionary that maps stype to a list of column names. The column
            names are sorted based on the ordering that appear in
            :obj:`tensor_frame.feat_dict`. Available as
            :obj:`tensor_frame.col_names_dict`.
        stype_encoder_dict
            (dict[:class:`torch_frame.stype`,
            :class:`torch_frame.nn.encoder.StypeEncoder`], optional):
            A dictionary mapping stypes into their stype encoders.
            (default: :obj:`None`, will call
            :class:`torch_frame.nn.encoder.EmbeddingEncoder()` for categorical
            feature and :class:`torch_frame.nn.encoder.LinearEncoder()`
            for numerical feature)

    Raises:
        ValueError: If :obj:`num_layers` is not greater than zero.
    """
    def __init__(
        self,
        channels: int,
        out_channels: int,
        num_layers: int,
        col_stats: dict[str, dict[StatType, Any]],
        col_names_dict: dict[torch_frame.stype, list[str]],
        stype_encoder_dict: dict[torch_frame.stype, StypeEncoder]
        | None = None,
    ) -> None:
        super().__init__()
        # Validate before building any sub-module so that a bad configuration
        # fails fast with a clear message.
        if num_layers <= 0:
            raise ValueError(
                f"num_layers must be a positive integer (got {num_layers})")

        # Fall back to the default encoders for categorical and numerical
        # columns when the caller does not supply its own mapping.
        if stype_encoder_dict is None:
            stype_encoder_dict = {
                stype.categorical: EmbeddingEncoder(),
                stype.numerical: LinearEncoder(),
            }

        # The encoder emits `channels`-dimensional features, which is the
        # width the backbone and the decoder's LayerNorm are built with.
        self.encoder = StypeWiseFeatureEncoder(
            out_channels=channels,
            col_stats=col_stats,
            col_names_dict=col_names_dict,
            stype_encoder_dict=stype_encoder_dict,
        )
        self.backbone = FTTransformerConvs(channels=channels,
                                           num_layers=num_layers)
        self.decoder = Sequential(
            LayerNorm(channels),
            ReLU(),
            Linear(channels, out_channels),
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        r"""Resets the parameters of the encoder, the backbone and every
        decoder layer other than the :class:`~torch.nn.ReLU` activation.
        """
        self.encoder.reset_parameters()
        self.backbone.reset_parameters()
        for m in self.decoder:
            # ReLU is skipped: it is the only decoder layer on which
            # `reset_parameters()` is not called.
            if not isinstance(m, ReLU):
                m.reset_parameters()

    def forward(self, tf: TensorFrame) -> Tensor:
        r"""Transforming :class:`TensorFrame` object into output prediction.

        Args:
            tf (TensorFrame): Input :class:`TensorFrame` object.

        Returns:
            torch.Tensor: Output of shape [batch_size, out_channels].
        """
        # The encoder returns a pair; only its first element is used here.
        x, _ = self.encoder(tf)
        # The backbone also returns a pair. Only the second element (named
        # `x_cls`; presumably the CLS-token embedding -- confirm against
        # FTTransformerConvs) is fed to the decoder.
        _, x_cls = self.backbone(x)
        return self.decoder(x_cls)