torch_frame.nn.decoder.TromptDecoder

class TromptDecoder(in_channels: int, out_channels: int, num_prompts: int)

Bases: Decoder

The Trompt downstream decoder introduced in the “Trompt: Towards a Better Deep Neural Network for Tabular Data” paper.

Parameters:
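  • in_channels (int) – Input channel dimensionality, i.e., the size of the last dimension of the input x.

  • out_channels (int) – Output channel dimensionality, i.e., the size of the last dimension of the output.

  • num_prompts (int) – Number of prompt columns.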
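To make the roles of the three constructor arguments concrete, here is a minimal pure-Python sketch of a Trompt-style downstream decoder. It assumes, based on our reading of the Trompt paper's downstream stage, that each of the num_prompts prompt embeddings receives one learned attention score, the scores are softmax-normalized across prompts, and the weighted sum of the embeddings passes through a small MLP that maps in_channels to out_channels. The class name TromptDecoderSketch, the weight names, the uniform initialization, and the omission of biases and normalization layers are illustrative assumptions. The real torch_frame module is written in PyTorch and its exact layer stack may differ.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


class TromptDecoderSketch:
    """Hypothetical pure-Python stand-in for TromptDecoder (not torch_frame code)."""

    def __init__(self, in_channels: int, out_channels: int, num_prompts: int,
                 seed: int = 0) -> None:
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_prompts = num_prompts
        self.rng = random.Random(seed)
        self.reset_parameters()

    def _uniform(self, rows: int, cols: int) -> Matrix:
        # rows x cols matrix drawn from U(-1/sqrt(cols), 1/sqrt(cols)).
        bound = 1.0 / math.sqrt(cols)
        return [[self.rng.uniform(-bound, bound) for _ in range(cols)]
                for _ in range(rows)]

    def reset_parameters(self) -> None:
        # Assumed layout: a scoring vector plus a two-layer MLP head (no biases).
        self.w_score = self._uniform(1, self.in_channels)[0]
        self.w_hidden = self._uniform(self.in_channels, self.in_channels)
        self.w_out = self._uniform(self.out_channels, self.in_channels)

    @staticmethod
    def _matvec(w: Matrix, v: List[float]) -> List[float]:
        return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]

    def forward(self, x: List[Matrix]) -> Matrix:
        """Map x of shape [batch_size][num_prompts][in_channels] to [batch_size][out_channels]."""
        out = []
        for sample in x:
            assert len(sample) == self.num_prompts
            assert all(len(row) == self.in_channels for row in sample)
            # 1. One score per prompt, softmax-normalized across prompts.
            scores = [sum(w * v for w, v in zip(self.w_score, row)) for row in sample]
            m = max(scores)
            exps = [math.exp(s - m) for s in scores]
            total = sum(exps)
            weights = [e / total for e in exps]
            # 2. Weighted sum over prompts: [num_prompts][in_channels] -> [in_channels].
            pooled = [sum(w * row[c] for w, row in zip(weights, sample))
                      for c in range(self.in_channels)]
            # 3. MLP head: in_channels -> in_channels (ReLU) -> out_channels.
            hidden = [max(0.0, h) for h in self._matvec(self.w_hidden, pooled)]
            out.append(self._matvec(self.w_out, hidden))
        return out


dec = TromptDecoderSketch(in_channels=4, out_channels=2, num_prompts=3)
x = [[[0.1 * (b + p + c) for c in range(4)] for p in range(3)] for b in range(5)]
y = dec.forward(x)
print(len(y), len(y[0]))  # 5 2
```

The only place num_prompts enters this sketch is the shape check, because the scoring layer is shared across prompts. in_channels and out_channels fix every weight shape.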

reset_parameters() → None

Resets all learnable parameters of the module.
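The effect of a reset can be sketched without PyTorch. The example below assumes, hypothetically, that the decoder is built from linear sub-layers that keep PyTorch's default nn.Linear initialization. Under that initialization, weights and biases are drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in)). LinearSketch and DecoderSketch are illustrative names, not torch_frame classes. In this sketch a reset draws fresh random values. It does not restore the exact values the module started with.

```python
import math
import random
from typing import List


class LinearSketch:
    """Hypothetical linear layer, used only to illustrate reset_parameters()."""

    def __init__(self, in_features: int, out_features: int, rng: random.Random) -> None:
        self.in_features = in_features
        self.out_features = out_features
        self.rng = rng
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Re-draw every weight and bias from U(-1/sqrt(fan_in), 1/sqrt(fan_in)),
        # the bound produced by PyTorch's default nn.Linear initialization.
        bound = 1.0 / math.sqrt(self.in_features)
        self.weight: List[List[float]] = [
            [self.rng.uniform(-bound, bound) for _ in range(self.in_features)]
            for _ in range(self.out_features)
        ]
        self.bias: List[float] = [self.rng.uniform(-bound, bound)
                                  for _ in range(self.out_features)]


class DecoderSketch:
    """Hypothetical container whose reset_parameters() delegates to each sub-layer."""

    def __init__(self, in_channels: int, out_channels: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        self.layers = [LinearSketch(in_channels, 1, rng),
                       LinearSketch(in_channels, out_channels, rng)]

    def reset_parameters(self) -> None:
        for layer in self.layers:
            layer.reset_parameters()


dec = DecoderSketch(in_channels=4, out_channels=2)
# Simulate training pushing every weight far outside the initial range.
for layer in dec.layers:
    layer.weight = [[w + 10.0 for w in row] for row in layer.weight]
dec.reset_parameters()
in_range = all(abs(w) <= 0.5  # bound = 1 / sqrt(4)
               for layer in dec.layers for row in layer.weight for w in row)
print(in_range)  # True
```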

forward(x: Tensor) → Tensor

Decode x of shape [batch_size, num_cols, channels] into an output tensor of shape [batch_size, out_channels].

Parameters:
  • x (torch.Tensor) – Input column-wise tensor of shape [batch_size, num_cols, channels], where channels equals in_channels.

  • args (Any) – Extra arguments.

  • kwargs (Any) – Extra keyword arguments.
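A small worked example helps explain how the num_cols dimension is removed on the way from [batch_size, num_cols, channels] to [batch_size, out_channels]. The sketch below shows this step for a single sample. The hypothetical helper softmax_pool takes one [num_cols, channels] sample and one attention score per column. It normalizes the scores with a softmax over the columns and returns the weighted sum of the column embeddings, which is a vector of length channels. In a Trompt-style decoder the scores would come from a learned layer, and the pooled vector would then pass through an MLP to reach out_channels (an assumption based on the paper, not on this page). Here the scores are supplied by hand so the arithmetic is easy to follow.

```python
import math
from typing import List


def softmax_pool(sample: List[List[float]], scores: List[float]) -> List[float]:
    """Collapse one [num_cols, channels] sample to [channels].

    Hypothetical helper: `scores` holds one attention logit per column.
    """
    assert len(sample) == len(scores)
    m = max(scores)                      # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]  # softmax over the num_cols axis
    channels = len(sample[0])
    return [sum(w * row[c] for w, row in zip(weights, sample))
            for c in range(channels)]


# One sample with num_cols = 2 and channels = 2.
sample = [[1.0, 2.0],
          [3.0, 4.0]]
print(softmax_pool(sample, [0.0, 0.0]))    # equal scores give the plain mean: [2.0, 3.0]
print(softmax_pool(sample, [0.0, 100.0]))  # the second column dominates
```

With equal scores the pooling reduces to a plain mean over columns. A much larger score for one column makes the output approach that column's embedding.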