# Source code for torch_frame.nn.models.tab_transformer
from __future__ import annotations
import math
from typing import Any
import torch
from torch import Tensor
from torch.nn import (
SELU,
BatchNorm1d,
Embedding,
LayerNorm,
Linear,
Module,
ModuleList,
Sequential,
)
import torch_frame
from torch_frame import TensorFrame, stype
from torch_frame.data.stats import StatType
from torch_frame.nn.conv import TabTransformerConv
from torch_frame.nn.encoder.stype_encoder import EmbeddingEncoder, StackEncoder
from torch_frame.typing import NAStrategy
class TabTransformer(Module):
    r"""The Tab-Transformer model introduced in the
    `"TabTransformer: Tabular Data Modeling Using Contextual Embeddings"
    <https://arxiv.org/abs/2012.06678>`_ paper.

    The model appends a learnable, column-specific embedding of size
    :obj:`encoder_pad_size` to every categorical feature embedding and
    executes multi-layer column-interaction modeling exclusively on the
    categorical features. For numerical features, the model simply applies
    layer normalization on the input features. The model utilizes an MLP
    (Multilayer Perceptron) for decoding.

    .. note::

        For an example of using TabTransformer, see `examples/tabtransformer.py
        <https://github.com/pyg-team/pytorch-frame/blob/master/examples/
        tabtransformer.py>`_.

    Args:
        channels (int): Dimensionality of each categorical column embedding
            after the column-specific padding has been appended, i.e., the
            channel dimensionality fed into the transformer layers.
        out_channels (int): Output channels dimensionality.
        num_layers (int): Number of convolution layers.
        num_heads (int): Number of heads in the self-attention layer.
        encoder_pad_size (int): Size of the learnable column embedding that
            is appended to each categorical embedding. The categorical
            encoder itself produces :obj:`channels - encoder_pad_size`
            channels, so :obj:`0 <= encoder_pad_size < channels` must hold
            when categorical columns are present.
        attn_dropout (float): Attention dropout rate, forwarded as
            :obj:`attn_dropout` to every :class:`TabTransformerConv` layer.
        ffn_dropout (float): Feed-forward dropout rate, forwarded as
            :obj:`ffn_dropout` to every :class:`TabTransformerConv` layer.
        col_stats(Dict[str,Dict[:class:`torch_frame.data.stats.StatType`,Any]]):
            A dictionary that maps column name into stats.
            Available as :obj:`dataset.col_stats`.
        col_names_dict (Dict[:class:`torch_frame.stype`, List[str]]): A
            dictionary that maps stype to a list of column names. The column
            names are sorted based on the ordering that appear in
            :obj:`tensor_frame.feat_dict`. Available as
            :obj:`tensor_frame.col_names_dict`.

    Raises:
        ValueError: If :obj:`num_layers` is not positive, if
            :obj:`encoder_pad_size` is outside :obj:`[0, channels)` while
            categorical columns are present, or if :obj:`col_names_dict`
            yields no categorical or numerical input features.
    """
    def __init__(
        self,
        channels: int,
        out_channels: int,
        num_layers: int,
        num_heads: int,
        encoder_pad_size: int,
        attn_dropout: float,
        ffn_dropout: float,
        col_stats: dict[str, dict[StatType, Any]],
        col_names_dict: dict[torch_frame.stype, list[str]],
    ) -> None:
        super().__init__()
        if num_layers <= 0:
            raise ValueError(
                f"num_layers must be a positive integer (got {num_layers})")

        self.col_names_dict = col_names_dict
        categorical_col_len = 0
        numerical_col_len = 0

        if stype.categorical in self.col_names_dict:
            # The encoder emits (channels - encoder_pad_size) channels and
            # the appended column embedding adds encoder_pad_size, so the
            # encoder's output size must stay positive.
            if not 0 <= encoder_pad_size < channels:
                raise ValueError(
                    f"encoder_pad_size must satisfy 0 <= encoder_pad_size < "
                    f"channels (got encoder_pad_size={encoder_pad_size}, "
                    f"channels={channels})")
            categorical_stats_list = [
                col_stats[col_name]
                for col_name in self.col_names_dict[stype.categorical]
            ]
            categorical_col_len = len(self.col_names_dict[stype.categorical])
            self.cat_encoder = EmbeddingEncoder(
                out_channels=channels - encoder_pad_size,
                stats_list=categorical_stats_list,
                stype=stype.categorical,
                na_strategy=NAStrategy.MOST_FREQUENT,
            )
            # One learnable row per categorical column; it is appended
            # (as contextual padding) to the end of each feature embedding.
            self.pad_embedding = Embedding(categorical_col_len,
                                           encoder_pad_size)
            # Apply transformer convolution only over categorical columns.
            self.tab_transformer_convs = ModuleList([
                TabTransformerConv(channels=channels, num_heads=num_heads,
                                   attn_dropout=attn_dropout,
                                   ffn_dropout=ffn_dropout)
                for _ in range(num_layers)
            ])

        if stype.numerical in self.col_names_dict:
            numerical_stats_list = [
                col_stats[col_name]
                for col_name in self.col_names_dict[stype.numerical]
            ]
            numerical_col_len = len(self.col_names_dict[stype.numerical])
            # Use stack encoder to normalize the numerical columns.
            self.num_encoder = StackEncoder(
                out_channels=1,
                stats_list=numerical_stats_list,
                stype=stype.numerical,
            )
            self.num_norm = LayerNorm(numerical_col_len)

        # Flattened categorical embeddings plus one value per numerical
        # column form the decoder input.
        mlp_input_len = categorical_col_len * channels + numerical_col_len
        if mlp_input_len == 0:
            # Without any input features the decoder would be zero-width and
            # forward() would fail on an empty torch.cat; fail early instead.
            raise ValueError(
                "TabTransformer requires at least one categorical or "
                "numerical column in col_names_dict")
        mlp_first_hidden_layer_size = 2 * mlp_input_len
        mlp_second_hidden_layer_size = 4 * mlp_input_len
        self.decoder = Sequential(
            Linear(mlp_input_len, mlp_first_hidden_layer_size),
            BatchNorm1d(mlp_first_hidden_layer_size),
            SELU(),
            Linear(mlp_first_hidden_layer_size, mlp_second_hidden_layer_size),
            BatchNorm1d(mlp_second_hidden_layer_size),
            SELU(),
            Linear(mlp_second_hidden_layer_size, out_channels),
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        r"""Re-initialize all learnable parameters of the model."""
        if stype.categorical in self.col_names_dict:
            self.cat_encoder.reset_parameters()
            torch.nn.init.normal_(self.pad_embedding.weight, std=0.01)
            for tab_transformer_conv in self.tab_transformer_convs:
                tab_transformer_conv.reset_parameters()
        if stype.numerical in self.col_names_dict:
            self.num_encoder.reset_parameters()
            self.num_norm.reset_parameters()
        for m in self.decoder:
            # SELU has no parameters and no reset_parameters() method.
            if not isinstance(m, SELU):
                m.reset_parameters()

    def forward(self, tf: TensorFrame) -> Tensor:
        r"""Transforming :class:`TensorFrame` object into output prediction.

        Args:
            tf (TensorFrame):
                Input :class:`TensorFrame` object.

        Returns:
            torch.Tensor: Output of shape [batch_size, out_channels].
        """
        xs = []
        batch_size = len(tf)
        if stype.categorical in self.col_names_dict:
            # Categorical embedding of shape
            # [batch_size, num_cols, channels - encoder_pad_size].
            x_cat = self.cat_encoder(tf.feat_dict[stype.categorical])
            # Broadcast the column embedding [num_cols, encoder_pad_size]
            # over the batch. `expand` returns a view, avoiding a per-sample
            # copy of the weight while keeping values and gradients intact.
            pos_enc_pad = self.pad_embedding.weight.unsqueeze(0).expand(
                batch_size, -1, -1)
            # The final categorical embedding is of size
            # [batch_size, num_cols, channels].
            x_cat = torch.cat((x_cat, pos_enc_pad), dim=-1)
            for tab_transformer_conv in self.tab_transformer_convs:
                x_cat = tab_transformer_conv(x_cat)
            # Flatten to [batch_size, num_cols * channels]. math.prod is used
            # instead of -1 so that an empty batch still reshapes cleanly.
            x_cat = x_cat.reshape(batch_size, math.prod(x_cat.shape[1:]))
            xs.append(x_cat)
        if stype.numerical in self.col_names_dict:
            x_num = self.num_encoder(tf.feat_dict[stype.numerical])
            x_num = x_num.reshape(batch_size, math.prod(x_num.shape[1:]))
            x_num = self.num_norm(x_num)
            xs.append(x_num)
        x = torch.cat(xs, dim=1)
        out = self.decoder(x)
        return out