torch_frame.nn.models.FTTransformer

class FTTransformer(channels: int, out_channels: int, num_layers: int, col_stats: dict[str, dict[StatType, Any]], col_names_dict: dict[torch_frame.stype, list[str]], stype_encoder_dict: dict[torch_frame.stype, StypeEncoder] | None = None)

Bases: Module

The FT-Transformer model introduced in the “Revisiting Deep Learning Models for Tabular Data” paper.
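As a rough illustration of that architecture, the sketch below re-implements its core idea in plain PyTorch for numerical columns only. Each column value becomes its own token through a per-column linear map (the "feature tokenizer"), a learnable [CLS] token is added to the sequence, a standard Transformer encoder processes the tokens, and the final [CLS] embedding goes through a LayerNorm, ReLU and Linear head. The class name TinyFTTransformer and its arguments are hypothetical; this is a simplified sketch under those assumptions, not torch_frame's implementation, and it leaves out categorical columns.

```python
import torch
from torch import Tensor, nn


class TinyFTTransformer(nn.Module):
    """Hypothetical, simplified FT-Transformer for numerical columns only."""

    def __init__(self, num_features: int, channels: int, out_channels: int,
                 num_layers: int, num_heads: int = 4) -> None:
        super().__init__()
        # Feature tokenizer: one weight and one bias vector per column, so
        # column j becomes the token x_j * weight[j] + bias[j].
        self.weight = nn.Parameter(torch.randn(num_features, channels) * 0.02)
        self.bias = nn.Parameter(torch.zeros(num_features, channels))
        # Learnable [CLS] token placed in front of the column tokens.
        self.cls = nn.Parameter(torch.zeros(1, 1, channels))
        layer = nn.TransformerEncoderLayer(
            d_model=channels, nhead=num_heads,
            dim_feedforward=2 * channels, batch_first=True)
        self.backbone = nn.TransformerEncoder(layer, num_layers=num_layers)
        # Prediction head applied to the final [CLS] embedding.
        self.head = nn.Sequential(nn.LayerNorm(channels), nn.ReLU(),
                                  nn.Linear(channels, out_channels))

    def forward(self, x: Tensor) -> Tensor:
        # x: [batch_size, num_features] -> [batch_size, num_features, channels]
        tokens = x.unsqueeze(-1) * self.weight + self.bias
        cls = self.cls.expand(x.size(0), -1, -1)
        h = self.backbone(torch.cat([cls, tokens], dim=1))
        return self.head(h[:, 0])  # [batch_size, out_channels]


model = TinyFTTransformer(num_features=5, channels=32, out_channels=1,
                          num_layers=2)
out = model(torch.randn(8, 5))
print(out.shape)  # torch.Size([8, 1])
```

Because every column gets its own token, self-attention can model interactions between any pair of columns, and the [CLS] token acts as a learned summary of the whole row.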

Note

For an example of using FTTransformer, see examples/revisiting.py.

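Parameters:

channels (int) – Hidden channel dimensionality.

out_channels (int) – Output channel dimensionality.

num_layers (int) – Number of Transformer layers.

col_stats (dict[str, dict[StatType, Any]]) – A dictionary that maps each column name to its statistics. Available as dataset.col_stats.

col_names_dict (dict[torch_frame.stype, list[str]]) – A dictionary that maps each stype to a list of column names, ordered as they appear in tensor_frame.feat_dict. Available as tensor_frame.col_names_dict.

stype_encoder_dict (dict[torch_frame.stype, StypeEncoder], optional) – A dictionary that maps stypes to their stype encoders. (default: None, in which case EmbeddingEncoder() is used for categorical features and LinearEncoder() for numerical features)
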
forward(tf: TensorFrame) → Tensor

Transforms a TensorFrame object into output predictions.

Parameters:

tf (TensorFrame) – Input TensorFrame object.

Returns:

Output of shape [batch_size, out_channels].

Return type:

torch.Tensor