torch_frame.nn.decoder.TromptDecoder
- class TromptDecoder(in_channels: int, out_channels: int, num_prompts: int)
Bases: Decoder
The Trompt downstream introduced in the “Trompt: Towards a Better Deep Neural Network for Tabular Data” paper.
- Parameters:
- forward(x: Tensor) → Tensor
Decode x of shape [batch_size, num_cols, channels] into an output tensor of shape [batch_size, out_channels].
- Parameters:
x (torch.Tensor) – Input column-wise tensor of shape [batch_size, num_cols, hidden_channels].
args (Any) – Extra arguments.
kwargs (Any) – Extra keyword arguments.
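
The shape contract described above (a column-wise tensor reduced to one output vector per row) can be illustrated with a minimal plain-PyTorch sketch. This is not the torch_frame implementation: `PooledDecoderSketch` is a hypothetical stand-in, and its internals (a softmax-weighted sum over the column axis followed by a small MLP) are an assumption loosely based on the downstream part of the Trompt paper. It is meant only to show how `[batch_size, num_cols, channels]` becomes `[batch_size, out_channels]`.

```python
import torch
from torch import nn


class PooledDecoderSketch(nn.Module):
    """Hypothetical stand-in; NOT the torch_frame TromptDecoder."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        # Produces one scalar score per column, used to weight the columns.
        self.lin_attn = nn.Linear(in_channels, 1)
        # Small MLP mapping the pooled vector to the output size (assumed layout).
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, in_channels),
            nn.ReLU(),
            nn.LayerNorm(in_channels),
            nn.Linear(in_channels, out_channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, num_cols, in_channels]
        weights = torch.softmax(self.lin_attn(x), dim=1)  # [batch_size, num_cols, 1]
        pooled = (weights * x).sum(dim=1)  # [batch_size, in_channels]
        return self.mlp(pooled)  # [batch_size, out_channels]


# Illustrative sizes only.
batch_size, num_cols, in_channels, out_channels = 4, 3, 8, 2
decoder = PooledDecoderSketch(in_channels, out_channels)
x = torch.randn(batch_size, num_cols, in_channels)
out = decoder(x)
```

Here `out` has one row per input row and `out_channels` entries per row. The column axis is removed by the weighted sum, so the output size does not depend on `num_cols`.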