torch_frame.nn.models.TabNet
- class TabNet(out_channels: int, num_layers: int, split_feat_channels: int, split_attn_channels: int, gamma: float, col_stats: dict[str, dict[torch_frame.data.stats.StatType, Any]], col_names_dict: dict[torch_frame._stype.stype, list[str]], stype_encoder_dict: Optional[dict[torch_frame._stype.stype, torch_frame.nn.encoder.stype_encoder.StypeEncoder]] = None, num_shared_glu_layers: int = 2, num_dependent_glu_layers: int = 2, cat_emb_channels: int = 2)
Bases: torch.nn.Module
The TabNet model introduced in the “TabNet: Attentive Interpretable Tabular Learning” paper.
Note
For an example of using TabNet, see examples/tabnet.py.
- Parameters:
out_channels (int) – Output dimensionality.
num_layers (int) – Number of TabNet layers.
split_feat_channels (int) – Dimensionality of feature channels.
split_attn_channels (int) – Dimensionality of attention channels.
gamma (float) – The gamma value for updating the prior for the attention mask.
col_stats (Dict[str, Dict[torch_frame.data.stats.StatType, Any]]) – A dictionary that maps each column name to its stats. Available as dataset.col_stats.
col_names_dict (Dict[torch_frame.stype, List[str]]) – A dictionary that maps each stype to a list of column names. The column names are sorted in the order in which they appear in tensor_frame.feat_dict. Available as tensor_frame.col_names_dict.
stype_encoder_dict (Dict[torch_frame.stype, torch_frame.nn.encoder.StypeEncoder], optional) – A dictionary mapping stypes to their stype encoders. (default: None, which uses EmbeddingEncoder() for categorical features and StackEncoder() for numerical features)
num_shared_glu_layers (int, optional) – Number of GLU layers shared across the num_layers FeatureTransformers. (default: 2)
num_dependent_glu_layers (int, optional) – Number of GLU layers used in each of the num_layers FeatureTransformers. (default: 2)
cat_emb_channels (int, optional) – The categorical embedding dimensionality. (default: 2)
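The gamma parameter controls how the attention prior changes from one step to the next. The following is a minimal sketch of the prior update described in the TabNet paper, P[i] = (gamma - M[i]) * P[i-1] with P[0] = 1, written in plain PyTorch rather than against torch_frame internals. The function name attention_priors and the random logits are hypothetical, and softmax stands in for the paper's sparsemax to keep the sketch short; the actual TabNet module may differ in detail.

```python
import torch


def attention_priors(logits: list[torch.Tensor], gamma: float) -> list[torch.Tensor]:
    """Hypothetical helper: return the prior used before each attention step.

    Follows the TabNet paper's update P[i] = (gamma - M[i]) * P[i - 1], P[0] = 1.
    `logits` holds one [batch_size, num_features] score tensor per step.
    Softmax is used here in place of the paper's sparsemax.
    """
    prior = torch.ones_like(logits[0])
    priors = []
    for step_logits in logits:
        priors.append(prior)
        # Scale the scores by the prior, as in M[i] = sparsemax(P[i-1] * h_i(a[i-1])).
        mask = torch.softmax(prior * step_logits, dim=-1)
        # Features that received attention lose prior mass for later steps.
        prior = (gamma - mask) * prior
    return priors


torch.manual_seed(0)
# Hypothetical scores for 3 steps, a batch of 4 rows and 5 features.
logits = [torch.randn(4, 5) for _ in range(3)]
priors = attention_priors(logits, gamma=1.3)
print([tuple(p.shape) for p in priors])
```

With gamma = 1, a feature that received most of the attention at one step has almost no prior left for later steps; larger gamma values let the model reuse a feature across several steps.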
- forward(tf: TensorFrame, return_reg: bool = False) → torch.Tensor | tuple[torch.Tensor, torch.Tensor]
Transform a TensorFrame object into output embeddings.
- Parameters:
tf (TensorFrame) – Input TensorFrame object.
return_reg (bool) – Whether to return the entropy regularization. (default: False)
- Returns:
The output embeddings of size [batch_size, out_channels]. If return_reg is True, the entropy regularization is returned as well.
- Return type:
Union[torch.Tensor, (torch.Tensor, torch.Tensor)]
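When return_reg is True, the second returned value can be added to the training loss as a sparsity penalty, as the TabNet paper does with its weighted sparsity term. The sketch below computes an entropy regularizer in that spirit (mean per-row entropy of the attention masks, averaged over steps) and adds it to a cross-entropy loss. It uses random stand-in tensors in place of real model outputs; entropy_regularization and lambda_sparse are hypothetical names, and torch_frame's internal formula may differ.

```python
import torch
import torch.nn.functional as F


def entropy_regularization(masks: list[torch.Tensor], eps: float = 1e-15) -> torch.Tensor:
    """Hypothetical helper: mean per-row entropy of attention masks, averaged over steps.

    Each mask is a [batch_size, num_features] tensor whose rows sum to 1.
    """
    # Entropy of each row, averaged over the batch, for every step.
    per_step = [-(m * torch.log(m + eps)).sum(dim=-1).mean() for m in masks]
    # Average over steps to get a scalar regularization term.
    return torch.stack(per_step).mean()


torch.manual_seed(0)
# Random stand-ins for 3 attention masks (batch of 8 rows, 5 features).
masks = [torch.softmax(torch.randn(8, 5), dim=-1) for _ in range(3)]
# Stand-ins for a [batch_size, out_channels] output and integer class labels.
out = torch.randn(8, 3, requires_grad=True)
y = torch.randint(0, 3, (8,))

reg = entropy_regularization(masks)
lambda_sparse = 1e-3  # hypothetical weight, not a TabNet constructor argument
loss = F.cross_entropy(out, y) + lambda_sparse * reg
loss.backward()
print(float(loss))
```

A larger lambda_sparse penalizes high-entropy masks more strongly, pushing each row to attend to fewer features.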