torch_frame.nn.conv.TromptConv
- class TromptConv(channels: int, num_cols: int, num_prompts: int, num_groups: int = 2)
Bases: TableConv
The Trompt cell introduced in the “Trompt: Towards a Better Deep Neural Network for Tabular Data” paper.
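- Parameters:
channels (int) – Embedding dimensionality shared by the feature embeddings and the prompt embeddings.
num_cols (int) – Number of columns (features) in the input.
num_prompts (int) – Number of prompt embeddings.
num_groups (int, optional) – Number of groups (default: 2).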
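The constructor arguments fix the tensor sizes that `forward` accepts: `num_cols` and `num_prompts` set the second dimension of `x` and `x_prompt`, `channels` sets the last dimension of both, and the two inputs share one batch size. The framework-free sketch below encodes that contract as a plain shape check; `expected_output_shape` is a hypothetical helper written for illustration and is not part of `torch_frame`.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def expected_output_shape(
    channels: int,
    num_cols: int,
    num_prompts: int,
    x_shape: Shape,
    x_prompt_shape: Shape,
) -> Shape:
    """Hypothetical helper: check the input shapes documented for
    TromptConv.forward and return the documented output shape."""
    if len(x_shape) != 3 or len(x_prompt_shape) != 3:
        raise ValueError("x and x_prompt must both be 3-D")
    batch_size = x_shape[0]
    # x: [batch_size, num_cols, channels]
    if tuple(x_shape) != (batch_size, num_cols, channels):
        raise ValueError(
            f"x must have shape [batch_size, {num_cols}, {channels}], got {tuple(x_shape)}"
        )
    # x_prompt: [batch_size, num_prompts, channels], same batch size as x
    if tuple(x_prompt_shape) != (batch_size, num_prompts, channels):
        raise ValueError(
            f"x_prompt must have shape [{batch_size}, {num_prompts}, {channels}], "
            f"got {tuple(x_prompt_shape)}"
        )
    # The output prompt embeddings keep the layout of x_prompt.
    return (batch_size, num_prompts, channels)


print(expected_output_shape(16, 5, 4, (32, 5, 16), (32, 4, 16)))  # (32, 4, 16)
```

Because the output has the same shape as `x_prompt`, it can be passed straight back in as the `x_prompt` of another cell with the same `channels` and `num_prompts`.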
- forward(x: Tensor, x_prompt: Tensor) → Tensor
Transforms the feature embeddings x and the prompt embeddings x_prompt into new prompt embeddings for the next layer.
- Parameters:
x (torch.Tensor) – Feature-based embeddings of shape [batch_size, num_cols, channels].
x_prompt (torch.Tensor) – Input prompt embeddings of shape [batch_size, num_prompts, channels].
- Returns:
Output prompt embeddings for the next layer, of shape [batch_size, num_prompts, channels].
- Return type:
torch.Tensor
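The Trompt paper computes, for each prompt, an importance weight over the columns and uses it to pool the feature embeddings; the resulting prompt embeddings then feed the next cell. The framework-free sketch below illustrates that data flow and the shapes documented above under simplifying assumptions: the column and prompt embeddings are random stand-ins for learned parameters, the previous prompt output is combined with the prompt embedding by plain addition, the initial prompt state is all zeros, and any learned projection or normalization layers the real module may use are left out. It is not torch_frame's implementation, and `trompt_cell_sketch` is a hypothetical name.

```python
import math
import random
from typing import List

Vec = List[float]
Mat = List[Vec]  # rows x channels


def softmax(scores: Vec) -> Vec:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a: Vec, b: Vec) -> float:
    return sum(p * q for p, q in zip(a, b))


def trompt_cell_sketch(
    x: List[Mat], x_prompt: List[Mat], col_emb: Mat, prompt_emb: Mat
) -> List[Mat]:
    """Simplified Trompt-style cell (assumption-laden sketch).

    x:          [batch_size][num_cols][channels]     feature embeddings
    x_prompt:   [batch_size][num_prompts][channels]  previous prompt output
    col_emb:    [num_cols][channels]     stand-in column embeddings
    prompt_emb: [num_prompts][channels]  stand-in prompt embeddings
    returns:    [batch_size][num_prompts][channels]
    """
    out: List[Mat] = []
    for feats, prev in zip(x, x_prompt):
        sample_out: Mat = []
        for p_emb, p_prev in zip(prompt_emb, prev):
            # Simplification: combine static prompt and previous output by addition only.
            query = [a + b for a, b in zip(p_emb, p_prev)]
            # Column importance: softmax over columns of prompt-column similarity.
            weights = softmax([dot(query, c) for c in col_emb])
            # Importance-weighted sum of the feature embeddings over the columns.
            sample_out.append(
                [sum(w * feats[j][k] for j, w in enumerate(weights)) for k in range(len(query))]
            )
        out.append(sample_out)
    return out


def rand_mat(rows: int, cols: int, rng: random.Random) -> Mat:
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
batch_size, num_cols, num_prompts, channels, num_layers = 2, 5, 4, 8, 3
x = [rand_mat(num_cols, channels, rng) for _ in range(batch_size)]
# Initial prompt state: zeros here (an assumption of this sketch).
x_prompt = [[[0.0] * channels for _ in range(num_prompts)] for _ in range(batch_size)]
for _ in range(num_layers):
    # Each layer gets its own stand-in embeddings; its output becomes the next x_prompt.
    col_emb = rand_mat(num_cols, channels, rng)
    prompt_emb = rand_mat(num_prompts, channels, rng)
    x_prompt = trompt_cell_sketch(x, x_prompt, col_emb, prompt_emb)

print(len(x_prompt), len(x_prompt[0]), len(x_prompt[0][0]))  # 2 4 8
```

Stacking works by feeding each cell's output back in as `x_prompt`, which is why the input and output prompt embeddings share the shape [batch_size, num_prompts, channels].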