torch_frame.nn.models.Trompt
- class Trompt(channels: int, out_channels: int, num_prompts: int, num_layers: int, col_stats: dict[str, dict[StatType, Any]], col_names_dict: dict[torch_frame.stype, list[str]], stype_encoder_dicts: list[dict[torch_frame.stype, StypeEncoder]] | None = None)
Bases: Module
The Trompt model introduced in the “Trompt: Towards a Better Deep Neural Network for Tabular Data” paper.
Note
For an example of using Trompt, see examples/trompt.py.
- Parameters:
channels (int) – Hidden channel dimensionality.
out_channels (int) – Output channel dimensionality.
num_prompts (int) – Number of prompt columns.
num_layers (int, optional) – Number of TromptConv layers. (default: 6)
col_stats (Dict[str, Dict[torch_frame.data.stats.StatType, Any]]) – A dictionary that maps each column name to its stats. Available as dataset.col_stats.
col_names_dict (Dict[torch_frame.stype, List[str]]) – A dictionary that maps each stype to a list of column names. The column names are sorted in the order in which they appear in tensor_frame.feat_dict. Available as tensor_frame.col_names_dict.
stype_encoder_dicts (list[dict[torch_frame.stype, torch_frame.nn.encoder.StypeEncoder]], optional) – A list of num_layers dictionaries, each mapping stypes to their stype encoders. (default: None, which uses EmbeddingEncoder() for categorical features and LinearEncoder() for numerical features)
- forward_stacked(tf: TensorFrame) → Tensor
Transforms a TensorFrame object into a series of output predictions, one per layer. Used during training to compute the layer-wise loss.
- Parameters:
tf (torch_frame.TensorFrame) – Input TensorFrame object.
- Returns:
Output predictions stacked across layers. The shape is [batch_size, num_layers, out_channels].
- Return type:
Tensor
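One plausible way to turn the stacked output into a layer-wise loss for classification is sketched below. It uses plain PyTorch only: a random tensor stands in for the result of `forward_stacked`, and the choice of cross-entropy, the label values, and the sizes are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_layers, num_classes = 4, 3, 2

# Stand-in for `model.forward_stacked(tf)`:
# shape [batch_size, num_layers, out_channels].
stacked = torch.randn(batch_size, num_layers, num_classes, requires_grad=True)
y = torch.tensor([0, 1, 1, 0])  # hypothetical class labels, one per row

# Flatten the layer axis into the batch axis:
# [batch_size * num_layers, num_classes].
pred = stacked.reshape(-1, num_classes)

# Repeat each label num_layers times so it lines up with the flattened
# predictions (row 0 layer 0, row 0 layer 1, ..., row 1 layer 0, ...).
y_repeated = y.repeat_interleave(num_layers)

# A single cross-entropy over all (row, layer) pairs.
loss = F.cross_entropy(pred, y_repeated)
loss.backward()
```

With the default mean reduction, this single call gives the same value as averaging the per-layer cross-entropy losses, so every layer's prediction contributes equally to the gradient.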
- forward(tf: TensorFrame) → Tensor
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of calling this function directly, since the former takes care of running the registered hooks while the latter silently ignores them.
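The difference described in the note can be seen with any `torch.nn.Module`. The sketch below uses a hypothetical `TinyModel` in place of Trompt so that it runs with PyTorch alone; the hook simply records the output shape each time it fires.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in module; any nn.Module behaves the same way."""

    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin(x)


model = TinyModel()
calls = []  # one entry per hook invocation


def record_shape(module, inputs, output):
    # Forward hook: log the output shape, leave the output unchanged.
    calls.append(tuple(output.shape))


handle = model.register_forward_hook(record_shape)
x = torch.randn(3, 4)

model(x)                          # goes through __call__, so the hook runs
hooks_after_call = len(calls)

model.forward(x)                  # bypasses __call__, so the hook is skipped
hooks_after_forward = len(calls)

handle.remove()                   # detach the hook when it is no longer needed
```

Here `hooks_after_call` and `hooks_after_forward` are equal, because only the call through the module instance triggers the registered hook.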