Modular Design of Deep Tabular Models
Our key observation is that many tabular deep learning models follow a modular design with three components (FeatureEncoder, TableConv, and Decoder), as shown in the figure below:

First, the input DataFrame with different columns is converted into a TensorFrame, where the columns are organized according to their stype (semantic type, such as categorical, numerical, or text). Then, the TensorFrame is fed into a FeatureEncoder, which converts the features of each stype into a 3-dimensional Tensor. The Tensors across different stypes are then concatenated into a single Tensor x of shape [batch_size, num_cols, num_channels]. The Tensor x is then updated iteratively via TableConvs. Finally, the updated Tensor x is given as input to a Decoder to produce the output Tensor of shape [batch_size, out_channels].
1. FeatureEncoder
FeatureEncoder transforms the input TensorFrame into x, a torch.Tensor of size [batch_size, num_cols, channels]. This class can contain learnable parameters and NaN (missing value) handling.
StypeWiseFeatureEncoder inherits from FeatureEncoder. It takes a TensorFrame as input and applies the stype-specific feature encoder (specified via stype_encoder_dict) to the Tensor of each stype to obtain embeddings for each stype. The embeddings of the different stypes are then concatenated to give the final 3-dimensional Tensor x of shape [batch_size, num_cols, channels].
Note
There exist user-facing and internal stypes. User-facing stypes are declared at the Dataset level, where users specify the stype for each column in the given DataFrame. The raw data of a user-facing stype is converted into data of an internal stype during materialization. We call the internal stype the parent of the user-facing stype.
For instance, stype.text_embedded is a user-facing stype because it declares the semantic type of the raw data stored in the DataFrame. During materialization, we convert the raw text into embeddings, making it indistinguishable from data stored as stype.embedding. The corresponding semantic type of the column thus becomes stype.embedding in the TensorFrame, and we consider stype.embedding the parent of stype.text_embedded.
Only parent semantic types are supported in stype_encoder_dict. The motivation for this design is that, internally, data of the same stype can be grouped together for efficiency.
Below is an example usage of StypeWiseFeatureEncoder, consisting of EmbeddingEncoder for encoding stype.categorical columns, LinearEmbeddingEncoder for encoding stype.embedding columns, and LinearEncoder for encoding stype.numerical columns.
from torch_frame import stype
from torch_frame.nn import (
    StypeWiseFeatureEncoder,
    EmbeddingEncoder,
    LinearEmbeddingEncoder,
    LinearEncoder,
)

stype_encoder_dict = {
    stype.categorical: EmbeddingEncoder(),
    stype.numerical: LinearEncoder(),
    stype.embedding: LinearEmbeddingEncoder(),
}

encoder = StypeWiseFeatureEncoder(
    out_channels=channels,
    col_stats=col_stats,
    col_names_dict=col_names_dict,
    stype_encoder_dict=stype_encoder_dict,
)
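Here, channels, col_stats, and col_names_dict are assumed to be defined elsewhere; they typically come from a materialized Dataset, which is also where the user-facing stypes from the note above are declared. A minimal sketch, assuming a toy pandas DataFrame with made-up column names:

import pandas as pd

from torch_frame import stype
from torch_frame.data import Dataset

# Toy DataFrame with hypothetical columns.
df = pd.DataFrame({
    'color': ['red', 'blue', 'red'],
    'price': [1.0, 2.5, 3.0],
    'label': [0, 1, 0],
})

# Declare the (user-facing) stype of each column at the Dataset level.
dataset = Dataset(
    df,
    col_to_stype={
        'color': stype.categorical,
        'price': stype.numerical,
        'label': stype.categorical,
    },
    target_col='label',
)
dataset.materialize()  # Builds the TensorFrame and computes column statistics.

channels = 32  # Arbitrary embedding width.
col_stats = dataset.col_stats
col_names_dict = dataset.tensor_frame.col_names_dict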
There are other encoders implemented as well, such as LinearBucketEncoder and ExcelFormerEncoder for numerical columns. See torch_frame.nn for the full list of built-in encoders.
You can also implement your own custom encoder for a given stype by inheriting from StypeEncoder.
2. TableConv
The table convolution layer inherits from TableConv. It takes the 3-dimensional Tensor x of shape [batch_size, num_cols, channels] as input and updates the column embeddings based on the embeddings of the other columns, thereby modeling the complex interactions among different column values.
Below, we show a simple self-attention-based table convolution to model the interaction among columns.
import math

import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear

from torch_frame.nn import TableConv


class SelfAttentionConv(TableConv):
    def __init__(self, channels: int):
        super().__init__()
        self.channels = channels
        # Linear functions for modeling key/query/value in self-attention.
        self.lin_k = Linear(channels, channels)
        self.lin_q = Linear(channels, channels)
        self.lin_v = Linear(channels, channels)

    def forward(self, x: Tensor) -> Tensor:
        # x has shape [batch_size, num_cols, channels]
        x_key = self.lin_k(x)
        x_query = self.lin_q(x)
        x_value = self.lin_v(x)
        prod = x_query.bmm(x_key.transpose(2, 1)) / math.sqrt(self.channels)
        # Attention weights between all pairs of columns.
        attn = F.softmax(prod, dim=-1)
        # Mix `x_value` based on the attention weights.
        out = attn.bmm(x_value)
        return out
Initializing and calling it is straightforward.
conv = SelfAttentionConv(32)
x = conv(x)
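As noted in the overview, x is usually updated iteratively by stacking several table convolution layers. A minimal sketch of that loop (the layer count is an arbitrary choice here):

from torch.nn import ModuleList

num_layers = 4  # Arbitrary depth.
convs = ModuleList([SelfAttentionConv(32) for _ in range(num_layers)])

for conv in convs:
    # Each layer maps [batch_size, num_cols, channels] to the same shape.
    x = conv(x)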
See torch_frame.nn for the full list of built-in convolution layers.
3. Decoder
Decoder transforms the input Tensor x into out, a Tensor of shape [batch_size, out_channels], representing the row embeddings of the original DataFrame.
Below is a simple example of a Decoder that mean-pools over the column embeddings, followed by a linear transformation.
import torch
from torch import Tensor
from torch.nn import Linear

from torch_frame.nn import Decoder


class MeanDecoder(Decoder):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.lin = Linear(in_channels, out_channels)

    def forward(self, x: Tensor) -> Tensor:
        # Mean pooling over the column dimension:
        # [batch_size, num_cols, in_channels] -> [batch_size, in_channels]
        out = torch.mean(x, dim=1)
        # [batch_size, out_channels]
        return self.lin(out)
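As with the convolution layer, it can be instantiated and called directly on x:

decoder = MeanDecoder(in_channels=32, out_channels=1)
out = decoder(x)  # [batch_size, 1]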
See torch_frame.nn for the full list of built-in decoders.
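Putting the three modules together gives a complete model over a TensorFrame. Below is a minimal sketch, not a built-in model: the class name ExampleTabularModel is made up, it reuses the SelfAttentionConv and MeanDecoder defined above, and it assumes that StypeWiseFeatureEncoder returns the encoded tensor together with the list of column names.

from torch import Tensor
from torch.nn import Module, ModuleList

from torch_frame import TensorFrame, stype
from torch_frame.nn import (
    StypeWiseFeatureEncoder,
    EmbeddingEncoder,
    LinearEncoder,
)


class ExampleTabularModel(Module):
    def __init__(self, channels: int, out_channels: int, num_layers: int,
                 col_stats, col_names_dict):
        super().__init__()
        # 1. FeatureEncoder: TensorFrame -> [batch_size, num_cols, channels]
        self.encoder = StypeWiseFeatureEncoder(
            out_channels=channels,
            col_stats=col_stats,
            col_names_dict=col_names_dict,
            stype_encoder_dict={
                stype.categorical: EmbeddingEncoder(),
                stype.numerical: LinearEncoder(),
            },
        )
        # 2. TableConv: iteratively update the column embeddings.
        self.convs = ModuleList(
            SelfAttentionConv(channels) for _ in range(num_layers))
        # 3. Decoder: [batch_size, num_cols, channels] -> [batch_size, out_channels]
        self.decoder = MeanDecoder(channels, out_channels)

    def forward(self, tf: TensorFrame) -> Tensor:
        x, _ = self.encoder(tf)
        for conv in self.convs:
            x = conv(x)
        return self.decoder(x)

Calling model(tensor_frame) then produces the output Tensor of shape [batch_size, out_channels], mirroring the FeatureEncoder, TableConv, and Decoder flow described at the top of this page.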