torch_frame.nn.models.TabNet

class TabNet(out_channels: int, num_layers: int, split_feat_channels: int, split_attn_channels: int, gamma: float, col_stats: dict[str, dict[StatType, Any]], col_names_dict: dict[torch_frame.stype, list[str]], stype_encoder_dict: dict[torch_frame.stype, StypeEncoder] | None = None, num_shared_glu_layers: int = 2, num_dependent_glu_layers: int = 2, cat_emb_channels: int = 2)

Bases: Module

The TabNet model introduced in the “TabNet: Attentive Interpretable Tabular Learning” paper.

Note

For an example of using TabNet, see examples/tabnet.py.

Parameters:
  • out_channels (int) – Output dimensionality.

  • num_layers (int) – Number of TabNet layers.

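  • split_feat_channels (int) – Dimensionality of the feature channels, i.e., the part of each FeatureTransformer output used for the decision output (N_d in the paper).

  • split_attn_channels (int) – Dimensionality of the attention channels, i.e., the part of each FeatureTransformer output passed on to the next attentive transformer (N_a in the paper).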

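  • gamma (float) – The relaxation factor used to update the prior for the attention mask. As described in the paper, gamma = 1 restricts each feature to a single decision step, and larger values allow features to be reused across steps.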

  • col_stats (Dict[str, Dict[torch_frame.data.stats.StatType, Any]]) – A dictionary that maps each column name to its stats. Available as dataset.col_stats.

  • col_names_dict (Dict[torch_frame.stype, List[str]]) – A dictionary that maps each stype to a list of column names. The column names are sorted in the order in which they appear in tensor_frame.feat_dict. Available as tensor_frame.col_names_dict.

  • stype_encoder_dict (dict[torch_frame.stype, torch_frame.nn.encoder.StypeEncoder], optional) – A dictionary mapping stypes to their stype encoders. (default: None, in which case EmbeddingEncoder() is used for categorical features and StackEncoder() for numerical features)

  • num_shared_glu_layers (int, optional) – Number of GLU layers shared across the num_layers FeatureTransformers. (default: 2)

  • num_dependent_glu_layers (int, optional) – Number of step-specific (non-shared) GLU layers in each of the num_layers FeatureTransformers. (default: 2)

  • cat_emb_channels (int, optional) – The categorical embedding dimensionality. (default: 2)
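The gamma parameter plays the role of the relaxation factor from the TabNet paper, where the prior is updated after each decision step as P[i] = P[i-1] * (gamma - M[i]), with M[i] the attention mask of step i and P[0] = 1. The sketch below applies that update to a hand-written mask for a single sample; update_prior and run_two_steps are hypothetical helpers, not part of torch_frame, and the mask values are made up.

```python
import torch


def update_prior(prior: torch.Tensor, mask: torch.Tensor, gamma: float) -> torch.Tensor:
    """Prior update from the TabNet paper: P[i] = P[i-1] * (gamma - M[i])."""
    return prior * (gamma - mask)


def run_two_steps(gamma: float) -> torch.Tensor:
    # Hypothetical setup: one sample with three features.
    prior = torch.ones(1, 3)                  # P[0]: every feature fully available
    mask = torch.tensor([[1.0, 0.0, 0.0]])    # step 1 puts all attention on feature 0
    return update_prior(prior, mask, gamma)   # prior seen by step 2's attention


strict = run_two_steps(gamma=1.0)   # prior for feature 0 drops to 0
relaxed = run_two_steps(gamma=1.5)  # prior for feature 0 is only reduced, to 0.5
print(strict, relaxed)
```

With gamma = 1.0 the prior of the already-selected feature falls to zero, which is how the paper discourages reusing it; with gamma = 1.5 the unselected features even gain prior weight (1.5), while feature 0 keeps a nonzero share.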

forward(tf: TensorFrame, return_reg: bool = False) → Tensor | tuple[Tensor, Tensor]

Transform TensorFrame object into output embeddings.

Parameters:
  • tf (TensorFrame) – Input TensorFrame object.

  • return_reg (bool, optional) – Whether to return the entropy regularization. (default: False)

Returns:

The output embeddings of size [batch_size, out_channels]. If return_reg is True, the entropy regularization is returned as well.

Return type:

Union[torch.Tensor, (torch.Tensor, torch.Tensor)]
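The entropy regularization penalizes attention masks that spread weight over many features. The sketch below mirrors the sparsity loss from the TabNet paper: the entropy of each mask row, averaged over the batch and then over decision steps. entropy_regularization is a hypothetical helper, and PyTorch Frame's internal computation may differ in details such as the eps value.

```python
import torch


def entropy_regularization(masks: list[torch.Tensor], eps: float = 1e-15) -> torch.Tensor:
    """Mean attention-mask entropy, following the TabNet paper's sparsity loss.

    masks: one [batch_size, num_features] attention mask per decision step,
    with each row summing to 1.
    """
    # Entropy per row, averaged over the batch, computed separately for each step.
    per_step = [(-m * torch.log(m + eps)).sum(dim=1).mean() for m in masks]
    # Average over decision steps.
    return torch.stack(per_step).mean()


sparse = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])  # one-hot rows
dense = torch.full((2, 4), 0.25)                                      # uniform rows

sparse_reg = entropy_regularization([sparse])  # entropy 0: each row picks one feature
dense_reg = entropy_regularization([dense])    # entropy log(4), about 1.386
print(sparse_reg.item(), dense_reg.item())
```

In the paper this term is added to the task loss with a coefficient λ_sparse. Assuming the value returned with return_reg=True plays the same role, a training step could use loss = task_loss + lambda_sparse * reg, where lambda_sparse is a hyperparameter you choose.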