torch_frame.nn.conv.TromptConv

class TromptConv(channels: int, num_cols: int, num_prompts: int, num_groups: int = 2)

Bases: TableConv

The Trompt cell introduced in the “Trompt: Towards a Better Deep Neural Network for Tabular Data” paper.

Parameters:
  • channels (int) – Input/output channel dimensionality.

  • num_cols (int) – Number of columns in the input table.

  • num_prompts (int) – Number of prompt columns.

  • num_groups (int) – Number of groups in group norm. (default: 2)

reset_parameters()

Resets all learnable parameters of the module.

forward(x: Tensor, x_prompt: Tensor) → Tensor

Combines the column embeddings x with the prompt embeddings x_prompt and returns the prompt embeddings for the next layer.

Parameters:
  • x (torch.Tensor) – Feature-based embeddings of shape [batch_size, num_cols, channels].

  • x_prompt (torch.Tensor) – Input prompt embeddings of shape [batch_size, num_prompts, channels].

Returns:

Output prompt embeddings for the next layer, of shape [batch_size, num_prompts, channels].

Return type:

torch.Tensor