torch_frame.nn.models.TabTransformer

class TabTransformer(channels: int, out_channels: int, num_layers: int, num_heads: int, encoder_pad_size: int, attn_dropout: float, ffn_dropout: float, col_stats: dict[str, dict[StatType, Any]], col_names_dict: dict[torch_frame.stype, list[str]])[source]

Bases: Module

The Tab-Transformer model introduced in the “TabTransformer: Tabular Data Modeling Using Contextual Embeddings” paper.

The model appends a column-wise positional (padding) embedding to each categorical feature embedding and performs multi-layer column-interaction modeling (self-attention) exclusively on the categorical features. Numerical features are simply layer-normalized. An MLP (multilayer perceptron) decodes the concatenated representation into the output.

Note

For an example of using TabTransformer, see examples/tabtransformer.py.
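To make the data flow described above concrete, here is a minimal, self-contained sketch of the same idea in plain PyTorch. The column counts, embedding sizes, and module choices are illustrative assumptions, not the library's internals:

```python
# Illustrative sketch of the TabTransformer data flow (not the torch_frame
# implementation). Assumed sizes: 3 categorical columns with 10 categories
# each, 4 numerical columns, channels=32, positional pad size=8.
import torch
import torch.nn as nn

num_cat_cols, num_cards, num_num_cols = 3, 10, 4
channels, pad_size, out_channels = 32, 8, 1

cat_embedding = nn.Embedding(num_cat_cols * num_cards, channels)
# One learnable positional (padding) embedding per categorical column.
col_pad = nn.Parameter(torch.randn(num_cat_cols, pad_size))
encoder_layer = nn.TransformerEncoderLayer(
    d_model=channels + pad_size, nhead=8, batch_first=True)
transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
num_norm = nn.LayerNorm(num_num_cols)
decoder = nn.Sequential(
    nn.Linear(num_cat_cols * (channels + pad_size) + num_num_cols, 64),
    nn.ReLU(),
    nn.Linear(64, out_channels),
)

x_cat = torch.randint(0, num_cards, (16, num_cat_cols))  # [batch, num_cat_cols]
x_num = torch.randn(16, num_num_cols)                    # [batch, num_num_cols]

# Offset category indices so each column uses its own embedding rows.
offsets = torch.arange(num_cat_cols) * num_cards
emb = cat_embedding(x_cat + offsets)                     # [batch, cols, channels]
# Append the per-column positional embedding to every categorical embedding.
emb = torch.cat([emb, col_pad.unsqueeze(0).expand(16, -1, -1)], dim=-1)
# Column-interaction modeling on categorical features only.
emb = transformer(emb)
# Numerical features are only layer-normalized, then everything is decoded.
out = decoder(torch.cat([emb.flatten(1), num_norm(x_num)], dim=-1))
print(out.shape)  # torch.Size([16, 1])
```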

Parameters:
  • channels (int) – Input channel dimensionality.

  • out_channels (int) – Output channel dimensionality.

  • num_layers (int) – Number of TabTransformer convolution (column-interaction) layers.

  • num_heads (int) – Number of heads in the self-attention layer.

  • encoder_pad_size (int) – Size of the positional encoding padding appended to the categorical embeddings.

  • attn_dropout (float) – Dropout probability of the self-attention module.

  • ffn_dropout (float) – Dropout probability of the feed-forward network.

  • col_stats (Dict[str, Dict[torch_frame.data.stats.StatType, Any]]) – A dictionary that maps each column name to its stats. Available as dataset.col_stats.

  • col_names_dict (Dict[torch_frame.stype, List[str]]) – A dictionary that maps each stype to a list of column names. The column names are sorted based on the ordering in which they appear in tensor_frame.feat_dict. Available as tensor_frame.col_names_dict.
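As a sketch of how these arguments are typically supplied (assuming a materialized torch_frame.data.Dataset named dataset; the hyperparameter values are illustrative, not defaults):

```python
# Hypothetical instantiation; `dataset` is assumed to be a materialized
# torch_frame.data.Dataset and the hyperparameter values are illustrative.
from torch_frame.nn.models import TabTransformer

model = TabTransformer(
    channels=32,
    out_channels=1,
    num_layers=6,
    num_heads=8,
    encoder_pad_size=2,
    attn_dropout=0.3,
    ffn_dropout=0.3,
    col_stats=dataset.col_stats,
    col_names_dict=dataset.tensor_frame.col_names_dict,
)
```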

forward(tf: TensorFrame) → Tensor[source]

Transform the input TensorFrame object into output predictions.

Parameters:

tf (TensorFrame) – Input TensorFrame object.

Returns:

Output of shape [batch_size, out_channels].

Return type:

torch.Tensor
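A minimal usage sketch of forward(), assuming the model and dataset from the instantiation example above:

```python
# `model` and `dataset` are assumed from the instantiation sketch above.
tf = dataset.tensor_frame  # TensorFrame holding all rows of the dataset
out = model(tf)            # equivalent to model.forward(tf)
print(out.shape)           # [num_rows, out_channels]
```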