torch_frame.nn.models.TabNet
- class TabNet(out_channels: int, num_layers: int, split_feat_channels: int, split_attn_channels: int, gamma: float, col_stats: dict[str, dict[StatType, Any]], col_names_dict: dict[torch_frame.stype, list[str]], stype_encoder_dict: dict[torch_frame.stype, StypeEncoder] | None = None, num_shared_glu_layers: int = 2, num_dependent_glu_layers: int = 2, cat_emb_channels: int = 2)
Bases: torch.nn.Module
The TabNet model introduced in the “TabNet: Attentive Interpretable Tabular Learning” paper.
Note
For an example of using TabNet, see examples/tabnet.py.
- Parameters:
out_channels (int) – Output dimensionality.
num_layers (int) – Number of TabNet layers.
split_feat_channels (int) – Dimensionality of the feature (decision) channels, N_d in the paper.
split_attn_channels (int) – Dimensionality of the attention channels, N_a in the paper.
gamma (float) – The gamma value for updating the prior for the attention mask.
col_stats (Dict[str, Dict[torch_frame.data.stats.StatType, Any]]) – A dictionary that maps each column name to its stats. Available as dataset.col_stats.
col_names_dict (Dict[torch_frame.stype, List[str]]) – A dictionary that maps each stype to a list of column names. The column names are sorted in the order in which they appear in tensor_frame.feat_dict. Available as tensor_frame.col_names_dict.
stype_encoder_dict (Dict[torch_frame.stype, torch_frame.nn.encoder.StypeEncoder], optional) – A dictionary mapping stypes to their stype encoders. (default: None, which uses EmbeddingEncoder() for categorical features and StackEncoder() for numerical features)
num_shared_glu_layers (int, optional) – Number of GLU layers shared across the num_layers FeatureTransformers. (default: 2)
num_dependent_glu_layers (int, optional) – Number of GLU layers used in each of the num_layers FeatureTransformers. (default: 2)
cat_emb_channels (int, optional) – The categorical embedding dimensionality. (default: 2)
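The effect of gamma can be illustrated with the prior update described in the TabNet paper: P[0] = 1 for every column, and after step i the prior becomes P[i] = P[i-1] · (gamma - M[i]), where M[i] is that step's attention mask. The pure-Python sketch below applies this rule to hand-written, hypothetical masks for a single row. It does not call torch_frame, and the function name `update_priors` is illustrative; the library's internal implementation may differ in detail.

```python
def update_priors(masks, gamma):
    """Return the prior after each step, following P[i] = P[i-1] * (gamma - M[i]).

    masks: per-step attention masks for a single row, each a list of
    non-negative weights over the input columns (hypothetical values).
    """
    num_cols = len(masks[0])
    prior = [1.0] * num_cols  # P[0]: every column is equally available
    history = []
    for mask in masks:
        # Columns that received attention at this step get a smaller prior.
        prior = [p * (gamma - m) for p, m in zip(prior, mask)]
        history.append(prior)
    return history


# Two steps over three columns; step 1 attends fully to column 0.
masks = [[1.0, 0.0, 0.0], [0.0, 0.5, 0.5]]
print(update_priors(masks, gamma=1.0))  # [[0.0, 1.0, 1.0], [0.0, 0.5, 0.5]]
print(update_priors(masks, gamma=1.5))  # [[0.5, 1.5, 1.5], [0.75, 1.5, 1.5]]
```

With gamma = 1.0, column 0 is used fully at step 1 and its prior drops to 0, so later steps cannot select it again; with gamma = 1.5 it keeps a nonzero prior and can be reused. Larger values of gamma therefore relax the constraint on reusing a feature across steps.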
- forward(tf: TensorFrame, return_reg: bool = False) → Tensor | tuple[Tensor, Tensor]
Transform a TensorFrame object into output embeddings.
- Parameters:
tf (TensorFrame) – Input TensorFrame object.
return_reg (bool) – Whether to also return the entropy regularization. (default: False)
- Returns:
The output embeddings of size [batch_size, out_channels]. If return_reg is True, the entropy regularization is returned as well, as the second element of a tuple.
- Return type:
Union[torch.Tensor, (torch.Tensor, torch.Tensor)]
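The regularization returned when return_reg is True penalizes the entropy of the attention masks. A minimal torch sketch of the sparsity regularizer as defined in the TabNet paper (per-row entropy of each step's mask, averaged over the batch and then over the steps) is shown below, applied to hypothetical masks. The function name `entropy_regularization` is illustrative, and the exact normalization used inside torch_frame may differ.

```python
import torch


def entropy_regularization(masks, eps=1e-15):
    """Sparsity regularizer in the style of the TabNet paper.

    masks: list of attention masks, one per step, each of shape
    [batch_size, num_cols] with rows summing to 1.
    """
    total = torch.zeros(())
    for mask in masks:
        # Entropy of each row's mask; eps keeps log() finite at zero weights.
        row_entropy = -(mask * torch.log(mask + eps)).sum(dim=1)  # [batch_size]
        total = total + row_entropy.mean()
    return total / len(masks)


# Hypothetical masks: batch of 2 rows over 3 columns.
one_hot = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
uniform = torch.full((2, 3), 1.0 / 3.0)
print(entropy_regularization([one_hot]))  # ~0: each row selects a single column
print(entropy_regularization([uniform]))  # ~log(3): attention spread evenly
```

In training, the second output would typically be added to the task loss with a small weight, e.g. `out, reg = model(tf, return_reg=True)` followed by `loss = F.cross_entropy(out, tf.y) + lambda_sparse * reg`, where `lambda_sparse` is a hyperparameter you choose (the paper's λ_sparse), not an argument of TabNet.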