Modular Design of Deep Tabular Models
Our key observation is that many tabular deep learning models follow a modular design with three components: FeatureEncoder, TableConv, and Decoder,
as shown in the figure below:
First, the input DataFrame with different columns is converted to TensorFrame, where the columns are organized according to their stype (semantic types such as categorical, numerical and text).
Then, the TensorFrame is fed into FeatureEncoder, which converts each stype feature into a 3-dimensional Tensor.
The Tensors across different stypes are then concatenated into a single Tensor x of shape [batch_size, num_cols, num_channels].
The Tensor x is then updated iteratively via TableConvs.
The updated Tensor x is given as input to Decoder to produce the output Tensor of shape [batch_size, out_channels].
1. FeatureEncoder
FeatureEncoder transforms input TensorFrame into x, a torch.Tensor of size [batch_size, num_cols, channels].
This class can contain learnable parameters and NaN (missing value) handling.
StypeWiseFeatureEncoder inherits from FeatureEncoder.
It takes a TensorFrame as input and applies a stype-specific feature encoder (specified via stype_encoder_dict) to the Tensor of each stype to get embeddings for each stype.
The embeddings of different stypes are then concatenated to give the final 3-dimensional Tensor x of shape [batch_size, num_cols, channels].
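The encode-then-concatenate step can be illustrated without torch_frame. The following is a minimal sketch in plain PyTorch, not the library's implementation: the stand-in encoders, the per-column weight shapes, and all variable names are assumptions chosen for illustration.

```python
import torch

batch_size, channels = 4, 8
num_categorical, num_numerical = 2, 3

# Hypothetical raw features, already grouped by stype.
feat_categorical = torch.randint(0, 5, (batch_size, num_categorical))  # category indices
feat_numerical = torch.randn(batch_size, num_numerical)

# Stand-in for a categorical encoder: one embedding table per column.
embeddings = torch.nn.ModuleList(
    [torch.nn.Embedding(5, channels) for _ in range(num_categorical)]
)
x_categorical = torch.stack(
    [emb(feat_categorical[:, i]) for i, emb in enumerate(embeddings)], dim=1
)  # [batch_size, num_categorical, channels]

# Stand-in for a numerical encoder: a separate scale and shift per column.
weight = torch.randn(num_numerical, channels)
bias = torch.zeros(num_numerical, channels)
x_numerical = feat_numerical.unsqueeze(-1) * weight + bias
# [batch_size, num_numerical, channels]

# Concatenate along the column dimension to obtain x.
x = torch.cat([x_categorical, x_numerical], dim=1)
print(x.shape)  # torch.Size([4, 5, 8])
```

Every column ends up with its own channels-dimensional embedding regardless of its stype, which is what lets later layers treat all columns uniformly.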
Note
There are two kinds of stypes: user-facing and internal.
User-facing stypes are declared on the Dataset level, where users can specify the stype for each column in the given DataFrame.
The raw data of the user-facing stype will be converted into data of internal stype during materialization.
We call the internal stype the parent of the user-facing stype.
For instance, stype.text_embedded is a user-facing stype because it declares the semantic type of the raw data stored in DataFrame.
During materialization, we convert the raw data stored as text into embeddings, which makes it no different from data stored as stype.embedding.
The corresponding semantic type of the column thus becomes stype.embedding in TensorFrame.
We consider stype.embedding to be the parent of stype.text_embedded.
Only parent semantic types are supported in the stype_encoder_dict.
The motivation for this design is that internally, data of the same stype can be grouped together for efficiency.
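The parent relationship can be pictured as a lookup followed by a group-by. The sketch below uses plain strings in place of torch_frame.stype members; the PARENT_STYPE table, the group_columns_by_parent helper, and the column names are hypothetical and serve only as illustration. Only the text_embedded to embedding mapping is taken from the note above.

```python
from collections import defaultdict

# Hypothetical lookup table: user-facing stype -> parent (internal) stype.
# Only the text_embedded -> embedding entry is taken from the note above.
PARENT_STYPE = {"text_embedded": "embedding"}


def group_columns_by_parent(col_to_stype: dict[str, str]) -> dict[str, list[str]]:
    """Group column names by the parent of their declared stype."""
    groups = defaultdict(list)
    for col, declared in col_to_stype.items():
        # A stype without an entry is treated as its own parent.
        groups[PARENT_STYPE.get(declared, declared)].append(col)
    return dict(groups)


grouped = group_columns_by_parent({
    "age": "numerical",
    "title": "text_embedded",
    "city": "categorical",
    "photo_vector": "embedding",
})
print(grouped)
```

In this sketch the "title" and "photo_vector" columns land in the same group, so a single encoder registered for the parent stype can process both at once.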
Below is an example usage of StypeWiseFeatureEncoder consisting of:
EmbeddingEncoder for encoding stype.categorical columns,
LinearEmbeddingEncoder for encoding stype.embedding columns, and
LinearEncoder for encoding stype.numerical columns.
from torch_frame import stype
from torch_frame.nn import (
    StypeWiseFeatureEncoder,
    EmbeddingEncoder,
    LinearEmbeddingEncoder,
    LinearEncoder,
)

# One encoder per parent stype.
stype_encoder_dict = {
    stype.categorical: EmbeddingEncoder(),
    stype.numerical: LinearEncoder(),
    stype.embedding: LinearEmbeddingEncoder(),
}

# `channels`, `col_stats` and `col_names_dict` are assumed to be defined
# beforehand; they are not created in this snippet.
encoder = StypeWiseFeatureEncoder(
    out_channels=channels,
    col_stats=col_stats,
    col_names_dict=col_names_dict,
    stype_encoder_dict=stype_encoder_dict,
)
Other built-in encoders are available as well, such as LinearBucketEncoder and ExcelFormerEncoder for numerical columns.
See torch_frame.nn for the full list of built-in encoders.
You can also implement a custom encoder for a given stype by inheriting from StypeEncoder.
2. TableConv
The table convolution layer inherits from TableConv.
It takes the 3-dimensional Tensor x of shape [batch_size, num_cols, channels] as input and
updates each column's embedding based on the embeddings of the other columns, thereby modeling the complex interactions among different column values.
Below, we show a simple self-attention-based table convolution to model the interaction among columns.
import math

import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear
from torch_frame.nn import TableConv


class SelfAttentionConv(TableConv):
    def __init__(self, channels: int):
        super().__init__()
        self.channels = channels
        # Linear functions for modeling key/query/value in self-attention.
        self.lin_k = Linear(channels, channels)
        self.lin_q = Linear(channels, channels)
        self.lin_v = Linear(channels, channels)

    def forward(self, x: Tensor) -> Tensor:
        # [batch_size, num_cols, channels]
        x_key = self.lin_k(x)
        x_query = self.lin_q(x)
        x_value = self.lin_v(x)
        # Scaled dot products: [batch_size, num_cols, num_cols]
        prod = x_query.bmm(x_key.transpose(2, 1)) / math.sqrt(self.channels)
        # Attention weights between all pairs of columns.
        attn = F.softmax(prod, dim=-1)
        # Mix `x_value` based on the attention weights.
        out = attn.bmm(x_value)
        return out
Initializing and calling it is straightforward.
conv = SelfAttentionConv(32)
x = conv(x)
See torch_frame.nn for the full list of built-in convolution layers.
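As a quick sanity check on the self-attention layer above, the following sketch repeats its attention computation in plain PyTorch on random input, so it runs without torch_frame. The tensor sizes and variable names are arbitrary choices for illustration.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_cols, channels = 4, 5, 32
x = torch.randn(batch_size, num_cols, channels)

lin_k = torch.nn.Linear(channels, channels)
lin_q = torch.nn.Linear(channels, channels)
lin_v = torch.nn.Linear(channels, channels)

# Scaled dot-product scores between all pairs of columns.
prod = lin_q(x).bmm(lin_k(x).transpose(2, 1)) / math.sqrt(channels)
attn = F.softmax(prod, dim=-1)  # [batch_size, num_cols, num_cols]
out = attn.bmm(lin_v(x))        # [batch_size, num_cols, channels]

print(attn.shape, out.shape)
# Each column's attention weights over all columns sum to 1.
print(torch.allclose(attn.sum(dim=-1), torch.ones(batch_size, num_cols)))
```

Because the output has the same shape as the input, layers of this kind can be applied repeatedly to update x.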
3. Decoder
Decoder transforms the input Tensor x into out, a Tensor of shape [batch_size, out_channels], representing
the row embeddings of the original DataFrame.
Below is a simple example of a Decoder that mean-pools over the column embeddings, followed by a linear transformation.
import torch
from torch_frame.nn import Decoder


class MeanDecoder(Decoder):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mean pooling over the column dimension
        # [batch_size, num_cols, in_channels] -> [batch_size, in_channels]
        out = torch.mean(x, dim=1)
        # [batch_size, out_channels]
        return self.lin(out)
See torch_frame.nn for the full list of built-in decoders.
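To see how the components fit together, here is a minimal sketch that chains a stack of convolutions and a decoder in plain PyTorch. It is not the torch_frame API: the modules subclass torch.nn.Module rather than TableConv and Decoder so that the block runs without the library, the encoder output is replaced by a random tensor, and the names TinyTabularModel and num_layers are hypothetical.

```python
import math

import torch
import torch.nn.functional as F
from torch import Tensor, nn


class SelfAttentionConv(nn.Module):
    """Stand-in for a TableConv: self-attention across columns."""

    def __init__(self, channels: int):
        super().__init__()
        self.channels = channels
        self.lin_k = nn.Linear(channels, channels)
        self.lin_q = nn.Linear(channels, channels)
        self.lin_v = nn.Linear(channels, channels)

    def forward(self, x: Tensor) -> Tensor:
        prod = self.lin_q(x).bmm(self.lin_k(x).transpose(2, 1))
        attn = F.softmax(prod / math.sqrt(self.channels), dim=-1)
        return attn.bmm(self.lin_v(x))


class MeanDecoder(nn.Module):
    """Stand-in for a Decoder: mean-pool columns, then a linear layer."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.lin = nn.Linear(in_channels, out_channels)

    def forward(self, x: Tensor) -> Tensor:
        return self.lin(x.mean(dim=1))


class TinyTabularModel(nn.Module):
    """Hypothetical model: several convolutions followed by a decoder."""

    def __init__(self, channels: int, out_channels: int, num_layers: int):
        super().__init__()
        self.convs = nn.ModuleList(
            [SelfAttentionConv(channels) for _ in range(num_layers)]
        )
        self.decoder = MeanDecoder(channels, out_channels)

    def forward(self, x: Tensor) -> Tensor:
        # Iteratively update the column embeddings.
        for conv in self.convs:
            x = conv(x)
        # [batch_size, num_cols, channels] -> [batch_size, out_channels]
        return self.decoder(x)


# Random tensor standing in for the FeatureEncoder output.
x = torch.randn(16, 7, 32)
model = TinyTabularModel(channels=32, out_channels=1, num_layers=3)
out = model(x)
print(out.shape)  # torch.Size([16, 1])
```

In a real model the random tensor would be replaced by the output of a FeatureEncoder applied to a TensorFrame.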