torch_frame.nn.models.Trompt

class Trompt(channels: int, out_channels: int, num_prompts: int, num_layers: int, col_stats: dict[str, dict[StatType, Any]], col_names_dict: dict[torch_frame.stype, list[str]], stype_encoder_dicts: list[dict[torch_frame.stype, StypeEncoder]] | None = None)

Bases: Module

The Trompt model introduced in the “Trompt: Towards a Better Deep Neural Network for Tabular Data” paper.

Note

For an example of using Trompt, see examples/trompt.py.

Parameters:
  • channels (int) – Hidden channel dimensionality.

  • out_channels (int) – Output channel dimensionality.

  • num_prompts (int) – Number of prompt columns.

  • num_layers (int) – Number of TromptConv layers.

  • col_stats (Dict[str, Dict[torch_frame.data.stats.StatType, Any]]) – A dictionary that maps each column name to its stats. Available as dataset.col_stats.

  • col_names_dict (Dict[torch_frame.stype, List[str]]) – A dictionary that maps each stype to a list of column names. The column names are sorted in the order in which they appear in tensor_frame.feat_dict. Available as tensor_frame.col_names_dict.

  • stype_encoder_dicts (list[dict[torch_frame.stype, torch_frame.nn.encoder.StypeEncoder]], optional) – A list of num_layers dictionaries, each mapping stypes to their stype encoders. (default: None, which uses EmbeddingEncoder() for categorical features and LinearEncoder() for numerical features)
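Because Trompt takes one encoder dictionary per layer, stype_encoder_dicts must line up with num_layers. The sketch below illustrates that contract and the documented default in plain Python so it runs without torch_frame installed. Plain strings stand in for torch_frame.stype members and encoder classes, resolve_stype_encoder_dicts is a hypothetical helper (not part of torch_frame), and the rule that every layer must cover every stype in col_names_dict is an assumption.

```python
# Plain strings stand in for torch_frame.stype members and encoder classes,
# so this sketch runs without torch_frame installed.
DEFAULT_ENCODERS = {
    "categorical": "EmbeddingEncoder",
    "numerical": "LinearEncoder",
}


def resolve_stype_encoder_dicts(num_layers, col_names_dict, stype_encoder_dicts=None):
    """Return one stype -> encoder mapping per Trompt layer (hypothetical helper)."""
    if stype_encoder_dicts is None:
        # Documented default: EmbeddingEncoder for categorical columns and
        # LinearEncoder for numerical columns, repeated for every layer.
        unsupported = set(col_names_dict) - set(DEFAULT_ENCODERS)
        if unsupported:
            raise ValueError(f"no default encoder for {sorted(unsupported)}")
        return [
            {stype: DEFAULT_ENCODERS[stype] for stype in col_names_dict}
            for _ in range(num_layers)
        ]
    if len(stype_encoder_dicts) != num_layers:
        raise ValueError(
            f"expected {num_layers} encoder dicts, got {len(stype_encoder_dicts)}"
        )
    for layer, encoders in enumerate(stype_encoder_dicts):
        # Assumption: each layer needs an encoder for every stype in the data.
        missing = set(col_names_dict) - set(encoders)
        if missing:
            raise ValueError(f"layer {layer} has no encoder for {sorted(missing)}")
    return stype_encoder_dicts


col_names_dict = {"categorical": ["city"], "numerical": ["age", "income"]}
per_layer = resolve_stype_encoder_dicts(2, col_names_dict)
print(per_layer[0])  # {'categorical': 'EmbeddingEncoder', 'numerical': 'LinearEncoder'}
```

Building a separate dictionary per layer mirrors the description above, where each of the num_layers layers gets its own set of stype encoders.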

forward_stacked(tf: TensorFrame) → Tensor

Transforms a TensorFrame object into a series of output predictions, one per layer. Used during training to compute a layer-wise loss.

Parameters:

tf (torch_frame.TensorFrame) – Input TensorFrame object.

Returns:

Output predictions stacked across layers. The shape is [batch_size, num_layers, out_channels].

Return type:

torch.Tensor
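Since forward_stacked keeps one prediction per layer, a training loss can score every layer against the same label. The sketch below shows that idea in plain Python so it runs without PyTorch. It assumes a classification task with cross-entropy (a regression task would use a different criterion), and layerwise_cross_entropy is a hypothetical name.

```python
import math


def layerwise_cross_entropy(stacked, targets):
    """Mean cross-entropy over every (sample, layer) prediction.

    stacked: nested lists shaped [batch_size][num_layers][out_channels] holding
             logits, mirroring the layout forward_stacked returns.
    targets: batch_size class indices; every layer of a sample shares its label.
    """
    total, count = 0.0, 0
    for sample_logits, target in zip(stacked, targets):
        for logits in sample_logits:
            # Numerically stable log-sum-exp, then -log softmax(logits)[target].
            m = max(logits)
            log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
            total += log_sum_exp - logits[target]
            count += 1
    return total / count


# One sample, two layers, two classes: equal logits cost log(2) per layer.
stacked = [[[0.0, 0.0], [0.0, 0.0]]]
print(layerwise_cross_entropy(stacked, [1]))  # about 0.693, i.e. log(2)
```

With tensors, the same averaging is commonly written as `F.cross_entropy(out.view(-1, out_channels), y.repeat_interleave(num_layers))`, where `out` is the output of forward_stacked. Using `repeat_interleave` rather than `repeat` keeps each label aligned with its own sample's layers after the reshape.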

forward(tf: TensorFrame) → Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
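Unlike forward_stacked, forward returns a single prediction per row. Assuming it collapses the stacked output by averaging over the layer dimension (roughly `model.forward_stacked(tf).mean(dim=1)` in PyTorch; check the implementation to confirm), the reduction can be sketched in plain Python. average_over_layers is a hypothetical name, and the averaging rule is the assumption being illustrated.

```python
def average_over_layers(stacked):
    """Collapse [batch_size][num_layers][out_channels] into [batch_size][out_channels].

    Assumption: the final prediction is the mean of the per-layer predictions.
    """
    averaged = []
    for sample in stacked:
        num_layers = len(sample)
        out_channels = len(sample[0])
        # Average each output channel across all layers of this sample.
        averaged.append(
            [sum(layer[c] for layer in sample) / num_layers for c in range(out_channels)]
        )
    return averaged


stacked = [[[1.0, 3.0], [3.0, 5.0]], [[0.0, 2.0], [4.0, 2.0]]]
print(average_over_layers(stacked))  # [[2.0, 4.0], [2.0, 2.0]]
```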