Modular Design of Deep Tabular Models

Our key observation is that many tabular deep learning models follow a modular design consisting of three components:

  1. FeatureEncoder

  2. TableConv

  3. Decoder

as shown in the figure below:

[Figure: modular design of deep tabular models (modular.png)]

  • First, the input DataFrame with different columns is converted into a TensorFrame, where the columns are organized according to their stype (semantic type, such as categorical, numerical, or text).

  • Then, the TensorFrame is fed into the FeatureEncoder, which converts the features of each stype into a 3-dimensional Tensor.

  • The Tensors across different stypes are then concatenated into a single Tensor x of shape [batch_size, num_cols, channels].

  • The Tensor x is then updated iteratively by a stack of TableConv layers.

  • The updated Tensor x is given as input to the Decoder, which produces an output Tensor of shape [batch_size, out_channels].
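To make the shapes in this pipeline concrete, here is a minimal plain-PyTorch sketch of the data flow. It does not use torch_frame: the per-stype encoder outputs are random placeholders, torch.nn.TransformerEncoderLayer stands in for a TableConv, and a mean-pool plus Linear stands in for the Decoder. All sizes are illustrative assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, channels, out_channels = 4, 8, 2

# Placeholders for FeatureEncoder outputs, one 3-D tensor per stype
# (random here; a real encoder computes them from the TensorFrame).
x_cat = torch.randn(batch_size, 3, channels)  # 3 categorical columns
x_num = torch.randn(batch_size, 5, channels)  # 5 numerical columns

# Concatenate along the column dimension -> [batch_size, num_cols, channels].
x = torch.cat([x_cat, x_num], dim=1)

# Stand-ins for TableConv layers: each keeps the shape [batch_size, num_cols, channels].
convs = nn.ModuleList(
    [nn.TransformerEncoderLayer(channels, nhead=2, batch_first=True) for _ in range(2)]
)
for conv in convs:
    x = conv(x)

# Stand-in Decoder: mean-pool over columns, then a linear map.
decoder = nn.Linear(channels, out_channels)
out = decoder(x.mean(dim=1))  # [batch_size, out_channels]
```

Note that only the column count changes at the concatenation step; every conv layer preserves the shape, and only the decoder removes the column dimension.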

1. FeatureEncoder

FeatureEncoder transforms the input TensorFrame into x, a torch.Tensor of shape [batch_size, num_cols, channels]. It can contain learnable parameters and can handle NaN (missing) values.
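As an illustration of both responsibilities, the sketch below is a hypothetical numerical encoder written as a plain torch.nn.Module (not torch_frame's API). It replaces NaNs with precomputed column means (mean imputation is our assumed strategy here) and then maps each column's scalar to a learnable channels-dimensional vector.

```python
import torch
from torch import Tensor, nn


class NaNMeanLinearEncoder(nn.Module):
    """Hypothetical numerical encoder: mean-imputes NaNs, then embeds each column."""

    def __init__(self, col_means: Tensor, channels: int):
        super().__init__()
        # Per-column means, e.g. computed on the training split (not learned).
        self.register_buffer("col_means", col_means)
        num_cols = col_means.numel()
        # Learnable per-column weight and bias: one `channels`-dim vector per column.
        self.weight = nn.Parameter(torch.randn(num_cols, channels))
        self.bias = nn.Parameter(torch.zeros(num_cols, channels))

    def forward(self, feat: Tensor) -> Tensor:
        # feat: [batch_size, num_cols], may contain NaN.
        feat = torch.where(torch.isnan(feat), self.col_means.expand_as(feat), feat)
        # [batch_size, num_cols, 1] * [num_cols, channels] -> [batch_size, num_cols, channels]
        return feat.unsqueeze(-1) * self.weight + self.bias


encoder = NaNMeanLinearEncoder(torch.tensor([0.5, 10.0]), channels=4)
feat = torch.tensor([[1.0, float("nan")], [float("nan"), 3.0]])
x = encoder(feat)  # [2, 2, 4], NaN-free
```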

StypeWiseFeatureEncoder inherits from FeatureEncoder. It takes a TensorFrame as input and applies a stype-specific feature encoder (specified via stype_encoder_dict) to the Tensor of each stype to obtain the embeddings of that stype. The embeddings of the different stypes are then concatenated to give the final 3-dimensional Tensor x of shape [batch_size, num_cols, channels].
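The dispatch-then-concatenate idea can be sketched in plain PyTorch. This is a simplified illustration, not StypeWiseFeatureEncoder itself: stypes are plain string keys rather than torch_frame's stype enum, and for brevity each encoder shares its parameters across all columns of its stype, which a real encoder need not do.

```python
import torch
from torch import Tensor, nn

torch.manual_seed(0)
batch_size, channels = 4, 8


class ScalarLinear(nn.Module):
    """Maps [batch_size, num_cols] floats to [batch_size, num_cols, channels]."""

    def __init__(self, channels: int):
        super().__init__()
        self.lin = nn.Linear(1, channels)

    def forward(self, feat: Tensor) -> Tensor:
        # [batch_size, num_cols, 1] -> [batch_size, num_cols, channels]
        return self.lin(feat.unsqueeze(-1))


# Hypothetical per-stype inputs keyed by plain strings (not torch_frame's stype enum).
feat_dict = {
    "categorical": torch.randint(0, 5, (batch_size, 3)),  # 3 columns, 5 categories
    "numerical": torch.randn(batch_size, 2),  # 2 columns
}

# One encoder per stype (parameters shared across that stype's columns).
encoder_dict = nn.ModuleDict({
    "categorical": nn.Embedding(5, channels),
    "numerical": ScalarLinear(channels),
})


def stype_wise_encode(feat_dict: dict, encoder_dict: nn.ModuleDict) -> Tensor:
    # Encode each stype separately, then concatenate along the column dimension.
    xs = [encoder_dict[name](feat) for name, feat in feat_dict.items()]
    return torch.cat(xs, dim=1)


x = stype_wise_encode(feat_dict, encoder_dict)  # [batch_size, 5, channels]
```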

Note

There are two kinds of stypes: user-facing stypes and internal stypes.

User-facing stypes are declared at the Dataset level, where users can specify the stype of each column in the given DataFrame. During materialization, the raw data of a user-facing stype is converted into data of an internal stype. We call this internal stype the parent of the user-facing stype. For instance, stype.text_embedded is a user-facing stype because it declares the semantic type of the raw data stored in the DataFrame.

During materialization, the raw text is converted into embeddings, which makes it indistinguishable from data stored as stype.embedding. The semantic type of the column in the TensorFrame therefore becomes stype.embedding, and we consider stype.embedding the parent of stype.text_embedded. Only parent stypes are supported as keys in stype_encoder_dict. The motivation for this design is that, internally, data of the same stype can be grouped together for efficiency.
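The grouping by parent stype can be illustrated with a small, purely hypothetical Python sketch. Only the text_embedded → embedding relation comes from the note above; the other entries (each stype being its own parent), the string names, and the column names are illustrative assumptions, not torch_frame's API.

```python
from collections import defaultdict

# Hypothetical map from user-facing stype names to parent (internal) stype names.
PARENT = {
    "categorical": "categorical",
    "numerical": "numerical",
    "embedding": "embedding",
    "text_embedded": "embedding",
}

# User-declared stype per column (illustrative column names).
col_to_stype = {
    "age": "numerical",
    "city": "categorical",
    "review": "text_embedded",
    "image_vec": "embedding",
}


def group_by_parent(col_to_stype: dict[str, str]) -> dict[str, list[str]]:
    """Group column names by the parent stype they are stored under internally."""
    groups: dict[str, list[str]] = defaultdict(list)
    for col, user_stype in col_to_stype.items():
        groups[PARENT[user_stype]].append(col)
    return dict(groups)


groups = group_by_parent(col_to_stype)
# "review" and "image_vec" end up in the same "embedding" group,
# so a single encoder can process them together.
```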

Below is an example usage of StypeWiseFeatureEncoder consisting of EmbeddingEncoder for encoding stype.categorical columns, LinearEmbeddingEncoder for encoding stype.embedding columns, and LinearEncoder for encoding stype.numerical columns.

from torch_frame import stype
from torch_frame.nn import (
    StypeWiseFeatureEncoder,
    EmbeddingEncoder,
    LinearEmbeddingEncoder,
    LinearEncoder,
)

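# Map each (parent) stype to the encoder applied to its columns.
stype_encoder_dict = {
    stype.categorical: EmbeddingEncoder(),
    stype.numerical: LinearEncoder(),
    stype.embedding: LinearEmbeddingEncoder(),
}

# `channels` (the embedding size), `col_stats` (per-column statistics) and
# `col_names_dict` (column names grouped by stype) are assumed to be defined,
# e.g., from a materialized dataset via `dataset.col_stats` and
# `dataset.tensor_frame.col_names_dict`.
encoder = StypeWiseFeatureEncoder(
    out_channels=channels,
    col_stats=col_stats,
    col_names_dict=col_names_dict,
    stype_encoder_dict=stype_encoder_dict,
)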

Other encoders are implemented as well, such as LinearBucketEncoder and ExcelFormerEncoder for numerical columns. See torch_frame.nn for the full list of built-in encoders.

You can also implement a custom encoder for a given stype by inheriting from StypeEncoder.
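As an example of the kind of computation a custom encoder might perform, the sketch below keeps a separate embedding table for each categorical column by packing all tables into one nn.Embedding with per-column offsets. It is written as a plain torch.nn.Module under our own hypothetical name; plugging it into torch_frame would mean subclassing StypeEncoder and implementing the methods that class requires, which this sketch does not attempt.

```python
import torch
from torch import Tensor, nn


class OffsetEmbeddingEncoder(nn.Module):
    """Hypothetical categorical encoder with a separate embedding table per column.

    All tables are packed into one nn.Embedding; category k of column j is
    stored at row offsets[j] + k.
    """

    def __init__(self, num_categories: list[int], channels: int):
        super().__init__()
        sizes = torch.tensor(num_categories)
        # offsets[j] = total number of categories in columns 0..j-1
        offsets = torch.cumsum(sizes, dim=0) - sizes
        self.register_buffer("offsets", offsets)
        self.emb = nn.Embedding(int(sizes.sum()), channels)

    def forward(self, feat: Tensor) -> Tensor:
        # feat: [batch_size, num_cols] integer category indices per column
        # -> [batch_size, num_cols, channels]
        return self.emb(feat + self.offsets)


encoder = OffsetEmbeddingEncoder([3, 2], channels=4)  # 3 and 2 categories
feat = torch.tensor([[0, 1], [2, 0]])
x = encoder(feat)  # [2, 2, 4]
```

A single packed table keeps the lookup to one embedding call regardless of the number of columns.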

2. TableConv

A table convolution layer inherits from TableConv. It takes the 3-dimensional Tensor x of shape [batch_size, num_cols, channels] as input and updates each column embedding based on the embeddings of the other columns, thereby modeling the complex interactions among column values. Below, we show a simple self-attention-based table convolution that models the interactions among columns.

import math

import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear
from torch_frame.nn import TableConv

class SelfAttentionConv(TableConv):
    def __init__(self, channels: int):
        super().__init__()
        self.channels = channels
        # Linear functions for modeling key/query/value in self-attention.
        self.lin_k = Linear(channels, channels)
        self.lin_q = Linear(channels, channels)
        self.lin_v = Linear(channels, channels)

    def forward(self, x: Tensor) -> Tensor:
        # x: [batch_size, num_cols, channels]
        x_key = self.lin_k(x)
        x_query = self.lin_q(x)
        x_value = self.lin_v(x)
        # Scaled dot products: [batch_size, num_cols, num_cols]
        prod = x_query.bmm(x_key.transpose(2, 1)) / math.sqrt(self.channels)
        # Attention weights between all pairs of columns.
        attn = F.softmax(prod, dim=-1)
        # Mix `x_value` based on the attention weights.
        out = attn.bmm(x_value)
        return out

Initializing and calling it is straightforward.

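# `x` is assumed to be a Tensor of shape [batch_size, num_cols, 32],
# e.g., the output of a FeatureEncoder with channels=32.
conv = SelfAttentionConv(32)
x = conv(x)  # Same shape as the input.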
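The iterative update of x amounts to applying several such layers in sequence. To keep the sketch runnable without torch_frame, it uses PyTorch's built-in torch.nn.MultiheadAttention as a stand-in for SelfAttentionConv; unlike the single-head layer above, it has several heads and an output projection. The layer count and sizes are arbitrary assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, num_cols, channels = 4, 6, 32
x = torch.randn(batch_size, num_cols, channels)

# Stand-in for a stack of TableConv layers: self-attention over the column dimension.
layers = nn.ModuleList(
    [nn.MultiheadAttention(channels, num_heads=4, batch_first=True) for _ in range(3)]
)
for attn in layers:
    # Query, key and value are all the current column embeddings.
    x, attn_weights = attn(x, x, x)
    # attn_weights: [batch_size, num_cols, num_cols], averaged over heads
```

Each layer preserves the shape [batch_size, num_cols, channels], so layers can be stacked freely, and each row of the attention weights is a distribution over columns.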

See torch_frame.nn for the full list of built-in convolution layers.

3. Decoder

Decoder transforms the input Tensor x into out, a Tensor of shape [batch_size, out_channels], representing the row embeddings of the original DataFrame.

Below is a simple example of a Decoder that mean-pools over the column embeddings, followed by a linear transformation.

import torch
from torch import Tensor
from torch.nn import Linear
from torch_frame.nn import Decoder

class MeanDecoder(Decoder):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.lin = Linear(in_channels, out_channels)

    def forward(self, x: Tensor) -> Tensor:
        # Mean pooling over the column dimension
        # [batch_size, num_cols, in_channels] -> [batch_size, in_channels]
        out = torch.mean(x, dim=1)
        # [batch_size, out_channels]
        return self.lin(out)
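Because mean pooling treats the columns as an unordered set, a decoder like MeanDecoder produces the same output regardless of column order. The quick check below reproduces the same computation in plain PyTorch (without torch_frame's Decoder base class) and compares the output before and after shuffling the columns; the sizes are arbitrary.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, num_cols, in_channels, out_channels = 4, 5, 8, 3
x = torch.randn(batch_size, num_cols, in_channels)
lin = nn.Linear(in_channels, out_channels)


def mean_decode(x: torch.Tensor) -> torch.Tensor:
    # Same computation as MeanDecoder.forward: mean over columns, then linear.
    return lin(x.mean(dim=1))


perm = torch.randperm(num_cols)
out = mean_decode(x)
out_perm = mean_decode(x[:, perm])  # same rows, columns shuffled
```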

See torch_frame.nn for the full list of built-in decoders.