# Source code for torch_frame.nn.models.excelformer
from __future__ import annotations
from typing import Any
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module, ModuleList
import torch_frame
from torch_frame import stype
from torch_frame.data.stats import StatType
from torch_frame.data.tensor_frame import TensorFrame
from torch_frame.nn.conv import ExcelFormerConv
from torch_frame.nn.decoder import ExcelFormerDecoder
from torch_frame.nn.encoder.stype_encoder import ExcelFormerEncoder
from torch_frame.nn.encoder.stypewise_encoder import (
StypeEncoder,
StypeWiseFeatureEncoder,
)
from torch_frame.typing import NAStrategy
def feature_mixup(
    x: Tensor,
    y: Tensor,
    num_classes: int,
    beta: float | Tensor = 0.5,
    mixup_type: str | None = None,
    mi_scores: Tensor | None = None,
) -> tuple[Tensor, Tensor]:
    r"""Mixup input numerical feature tensor :obj:`x` by swapping some
    feature elements of two shuffled samples. The shuffle rate of each row
    is sampled from the Beta distribution. The target :obj:`y` is also
    linearly mixed up with the same weights.

    Args:
        x (Tensor): The encoded numerical feature of size
            :obj:`[batch_size, num_cols, in_channels]`.
        y (Tensor): The target, expected to be of size :obj:`[batch_size]`.
            Must be an integer (class index) tensor when
            :obj:`num_classes > 1`, since it is one-hot encoded.
        num_classes (int): Number of classes. Use :obj:`1` for regression
            or binary classification, in which case :obj:`y` is mixed
            directly instead of being one-hot encoded.
        beta (float or Tensor): The concentration parameter of the Beta
            distribution. (default: :obj:`0.5`)
        mixup_type (str, optional): The mixup method. No mixup if set to
            :obj:`None`. Options `feature` and `hidden` are `FEAT-MIX`
            (mixup at feature dimension) and `HIDDEN-MIX` (mixup at
            hidden dimension) proposed in the ExcelFormer paper.
            (default: :obj:`None`)
        mi_scores (Tensor, optional): Mutual information scores of size
            :obj:`[num_cols]`, only used in the mixup weight calculation
            for `FEAT-MIX`. If all scores are zero, uniform weights are
            used instead. (default: :obj:`None`)

    Returns:
        x_mixedup (Tensor): The mixed-up numerical feature, with the same
            size as :obj:`x`.
        y_mixedup (Tensor): The mixed-up target, of size
            :obj:`[batch_size, num_classes]` when :obj:`num_classes > 1`
            and of size :obj:`[batch_size]` when :obj:`num_classes == 1`.
    """
    assert num_classes > 0
    assert mixup_type in [None, 'feature', 'hidden']
    assert x.ndim == 3, (
        "FEAT-MIX or HIDDEN-MIX is for encoded numerical features "
        "of size [batch_size, num_cols, in_channels].")

    # `as_tensor` accepts both floats and tensors; `torch.tensor` would emit
    # a copy-construct warning when `beta` is already a tensor.
    concentration = torch.as_tensor(beta, dtype=x.dtype, device=x.device)
    beta_distribution = torch.distributions.beta.Beta(concentration,
                                                      concentration)
    # One shuffle rate per row, shaped [batch_size, 1] so it broadcasts
    # against per-row masks below.
    shuffle_rates = beta_distribution.sample(torch.Size((len(x), 1)))
    shuffled_idx = torch.randperm(len(x), device=x.device)

    b, f, d = x.shape
    if mixup_type == 'feature':
        assert mi_scores is not None, \
            "`mi_scores` is required for mixup_type='feature'."
        # Match x's dtype as well as device so that `lam` (and hence the
        # mixed-up target) is not silently promoted, e.g. to float64.
        mi_scores = mi_scores.to(device=x.device, dtype=x.dtype)
        # Hard mask (feature dimension): True keeps the original element.
        mixup_mask = torch.rand(torch.Size((b, f)),
                                device=x.device) < shuffle_rates
        # L1 normalized mutual information scores. Fall back to uniform
        # weights when every score is zero to avoid 0/0 = NaN targets;
        # `torch.where` keeps this free of data-dependent Python branches.
        total = mi_scores.sum()
        uniform = torch.full_like(mi_scores, 1.0 / f)
        norm_mi_scores = torch.where(total > 0, mi_scores / total, uniform)
        # Mixup weights: MI-weighted share of features kept from the
        # original sample.
        lam = torch.sum(norm_mi_scores.unsqueeze(0) * mixup_mask, dim=1,
                        keepdim=True)
        mixup_mask = mixup_mask.unsqueeze(2)
    elif mixup_type == 'hidden':
        # Hard mask (hidden dimension): True keeps the original element.
        mixup_mask = torch.rand(torch.Size((b, d)),
                                device=x.device) < shuffle_rates
        mixup_mask = mixup_mask.unsqueeze(1)
        # Mixup weights
        lam = shuffle_rates
    else:
        # No mixup: keep every element of x.
        mixup_mask = torch.ones_like(x, dtype=torch.bool)
        # Fake mixup weights (the original target gets full weight).
        lam = torch.ones_like(shuffle_rates)

    x_mixedup = mixup_mask * x + ~mixup_mask * x[shuffled_idx]
    y_shuffled = y[shuffled_idx]
    if num_classes == 1:
        # Regression task or binary classification
        lam = lam.squeeze(1)
        y_mixedup = lam * y + (1 - lam) * y_shuffled
    else:
        # Classification task
        one_hot_y = F.one_hot(y, num_classes=num_classes)
        one_hot_y_shuffled = F.one_hot(y_shuffled, num_classes=num_classes)
        y_mixedup = lam * one_hot_y + (1 - lam) * one_hot_y_shuffled
    return x_mixedup, y_mixedup
class ExcelFormer(Module):
    r"""The ExcelFormer model introduced in the
    `"ExcelFormer: A Neural Network Surpassing GBDTs on Tabular Data"
    <https://arxiv.org/abs/2301.02819>`_ paper.

    ExcelFormer first converts the categorical features with a target
    statistics encoder (i.e., :class:`CatBoostEncoder` in the paper)
    into numerical features. Then it sorts the numerical features
    with mutual information sort. Hence the model itself only accepts
    numerical features.

    .. note::

        For an example of using ExcelFormer, see `examples/excelformer.py
        <https://github.com/pyg-team/pytorch-frame/blob/master/examples/
        excelformer.py>`_.

    Args:
        in_channels (int): Input channel dimensionality, i.e. the size of
            each encoded column embedding fed to the
            :class:`torch_frame.nn.conv.ExcelFormerConv` layers.
        out_channels (int): Output channels dimensionality. It is also
            used as :obj:`num_classes` for the mixed-up targets in
            :meth:`forward`.
        num_cols (int): Number of columns
        num_layers (int): Number of
            :class:`torch_frame.nn.conv.ExcelFormerConv` layers. Must be
            a positive integer.
        num_heads (int): Number of attention heads used in :class:`DiaM`
        col_stats(dict[str,dict[:class:`torch_frame.data.stats.StatType`,Any]]):
            A dictionary that maps column name into stats.
            Available as :obj:`dataset.col_stats`.
        col_names_dict (dict[:obj:`torch_frame.stype`, list[str]]): A
            dictionary that maps stype to a list of column names. The column
            names are sorted based on the ordering that appear in
            :obj:`tensor_frame.feat_dict`. Available as
            :obj:`tensor_frame.col_names_dict`. Its only key must be
            :obj:`stype.numerical`.
        stype_encoder_dict
            (dict[:class:`torch_frame.stype`,
            :class:`torch_frame.nn.encoder.StypeEncoder`], optional):
            A dictionary mapping stypes into their stype encoders.
            (default: :obj:`None`, will use an :obj:`ExcelFormerEncoder`
            with :obj:`NAStrategy.MEAN` for numerical features)
        diam_dropout (float, optional): Dropout rate forwarded to each
            :class:`ExcelFormerConv` as its :obj:`diam_dropout` argument.
            (default: :obj:`0.0`)
        aium_dropout (float, optional): Dropout rate forwarded to each
            :class:`ExcelFormerConv` as its :obj:`aium_dropout` argument.
            (default: :obj:`0.0`)
        residual_dropout (float, optional): Residual dropout rate forwarded
            to each :class:`ExcelFormerConv`. (default: :obj:`0.0`)
        mixup (str, optional): Mixup type used when :meth:`forward` is
            called with :obj:`mixup_encoded=True`:
            :obj:`None`, :obj:`feature`, or :obj:`hidden`.
            (default: :obj:`None`)
        beta (float, optional): Shape parameter for the Beta distribution
            used to sample the shuffle rate in mixup. Only useful when
            `mixup` is not :obj:`None`. (default: :obj:`0.5`)
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_cols: int,
        num_layers: int,
        num_heads: int,
        col_stats: dict[str, dict[StatType, Any]],
        col_names_dict: dict[torch_frame.stype, list[str]],
        stype_encoder_dict: dict[torch_frame.stype, StypeEncoder]
        | None = None,
        diam_dropout: float = 0.0,
        aium_dropout: float = 0.0,
        residual_dropout: float = 0.0,
        mixup: str | None = None,
        beta: float = 0.5,
    ) -> None:
        super().__init__()
        if num_layers <= 0:
            raise ValueError(
                f"num_layers must be a positive integer (got {num_layers})")
        assert mixup in [None, 'feature', 'hidden']
        self.in_channels = in_channels
        self.out_channels = out_channels
        if col_names_dict.keys() != {stype.numerical}:
            raise ValueError("ExcelFormer only accepts numerical "
                             "features.")
        if stype_encoder_dict is None:
            # The encoder feeds the ExcelFormerConv layers, which operate on
            # `in_channels`-dim column embeddings (see the feature encoder
            # below), so it must be built with `in_channels`, not the
            # model's `out_channels`.
            stype_encoder_dict = {
                stype.numerical:
                ExcelFormerEncoder(in_channels, na_strategy=NAStrategy.MEAN)
            }
        self.excelformer_encoder = StypeWiseFeatureEncoder(
            out_channels=self.in_channels,
            col_stats=col_stats,
            col_names_dict=col_names_dict,
            stype_encoder_dict=stype_encoder_dict,
        )
        self.excelformer_convs = ModuleList([
            ExcelFormerConv(in_channels, num_cols, num_heads, diam_dropout,
                            aium_dropout, residual_dropout)
            for _ in range(num_layers)
        ])
        self.excelformer_decoder = ExcelFormerDecoder(in_channels,
                                                      out_channels, num_cols)
        self.reset_parameters()
        self.mixup = mixup
        self.beta = beta

    def reset_parameters(self) -> None:
        r"""Reset the parameters of the encoder, every conv layer and the
        decoder."""
        self.excelformer_encoder.reset_parameters()
        for excelformer_conv in self.excelformer_convs:
            excelformer_conv.reset_parameters()
        self.excelformer_decoder.reset_parameters()

    def forward(
        self,
        tf: TensorFrame,
        mixup_encoded: bool = False,
    ) -> Tensor | tuple[Tensor, Tensor]:
        r"""Transform :class:`TensorFrame` object into output embeddings. If
        :obj:`mixup_encoded` is :obj:`True`, it produces the output embeddings
        together with the mixed-up targets in :obj:`self.mixup` manner.

        Args:
            tf (:class:`torch_frame.TensorFrame`): Input :class:`TensorFrame`
                object. Must carry targets (:obj:`tf.y`) when
                :obj:`mixup_encoded` is :obj:`True`, and
                :obj:`tf.mi_scores` when :obj:`self.mixup` is
                :obj:`"feature"`.
            mixup_encoded (bool): Whether to mixup on encoded numerical
                features, i.e., `FEAT-MIX` and `HIDDEN-MIX`. With
                :obj:`self.mixup` set to :obj:`None`, the features are left
                unchanged and only the transformed targets are returned.
                (default: :obj:`False`)

        Returns:
            torch.Tensor | tuple[Tensor, Tensor]: The output embeddings of
            size [batch_size, out_channels]. If :obj:`mixup_encoded` is
            :obj:`True`, the mixed-up targets are returned as well, of size
            [batch_size, out_channels] when :obj:`out_channels > 1` and of
            size [batch_size] when :obj:`out_channels == 1`.
        """
        x, _ = self.excelformer_encoder(tf)
        # FEAT-MIX or HIDDEN-MIX is compatible with `torch.compile`
        if mixup_encoded:
            assert tf.y is not None
            x, y_mixedup = feature_mixup(
                x,
                tf.y,
                num_classes=self.out_channels,
                beta=self.beta,
                mixup_type=self.mixup,
                mi_scores=getattr(tf, 'mi_scores', None),
            )
        for excelformer_conv in self.excelformer_convs:
            x = excelformer_conv(x)
        out = self.excelformer_decoder(x)
        if mixup_encoded:
            return out, y_mixedup
        return out