torch_frame.nn.conv.TabTransformerConv

class TabTransformerConv(channels: int, num_heads: int, attn_dropout: float = 0.0, ffn_dropout: float = 0.0)

Bases: TableConv

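The TabTransformer layer introduced in the “TabTransformer: Tabular Data Modeling Using Contextual Embeddings” paper. It applies multi-head self-attention across the columns of each row, followed by a feed-forward network.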

Parameters:
  • channels (int) – Input/output channel dimensionality

  • num_heads (int) – Number of attention heads

  • attn_dropout (float) – Dropout probability in the attention module (default: 0.0)

  • ffn_dropout (float) – Dropout probability in the feed-forward network (default: 0.0)

forward(x: Tensor) → Tensor

Processes a column-wise 3-dimensional tensor into another column-wise 3-dimensional tensor of the same shape.

Parameters:
  • x (torch.Tensor) – Input column-wise tensor of shape [batch_size, num_cols, hidden_channels], where hidden_channels equals channels.

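  • args (Any) – Extra arguments, inherited from the TableConv interface; the forward(x) signature of this layer accepts only x.

  • kwargs (Any) – Extra keyword arguments, inherited from the TableConv interface; not accepted by this layer's forward(x).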
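To make the column-wise computation concrete, here is a self-contained sketch in plain PyTorch of a block modeled on the standard Transformer encoder layer that the TabTransformer paper builds on: multi-head self-attention across the column axis, then a position-wise feed-forward network, each followed by a residual connection and LayerNorm. This is not torch_frame's implementation; the class name ColumnTransformerBlock, the 4x hidden width, and the GELU activation are illustrative assumptions.

```python
import torch
from torch import Tensor, nn


class ColumnTransformerBlock(nn.Module):
    """Hypothetical TabTransformer-style block (not torch_frame's code)."""

    def __init__(self, channels: int, num_heads: int,
                 attn_dropout: float = 0.0, ffn_dropout: float = 0.0) -> None:
        super().__init__()
        # batch_first=True so inputs are [batch_size, num_cols, channels].
        self.attn = nn.MultiheadAttention(channels, num_heads,
                                          dropout=attn_dropout, batch_first=True)
        self.norm_1 = nn.LayerNorm(channels)
        # Position-wise FFN: applied to each column embedding independently.
        # The 4x expansion and GELU are assumptions for illustration.
        self.ffn = nn.Sequential(
            nn.Linear(channels, 4 * channels),
            nn.GELU(),
            nn.Dropout(ffn_dropout),
            nn.Linear(4 * channels, channels),
        )
        self.norm_2 = nn.LayerNorm(channels)

    def forward(self, x: Tensor) -> Tensor:
        # Each column attends to every column of the same row.
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = self.norm_1(x + attn_out)       # residual + LayerNorm
        x = self.norm_2(x + self.ffn(x))    # residual + LayerNorm
        return x


torch.manual_seed(0)
block = ColumnTransformerBlock(channels=32, num_heads=4)
x = torch.randn(8, 5, 32)  # [batch_size, num_cols, channels]
out = block(x)
print(out.shape)  # torch.Size([8, 5, 32])
```

No operation in this sketch depends on column order, so permuting the columns of x permutes the output columns in the same way; any notion of which column is which has to be carried by the input embeddings.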

reset_parameters()

Resets all learnable parameters of the module.