# Source code for torch_frame.nn.conv.trompt_conv

from __future__ import annotations

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import GroupNorm, LayerNorm, Linear, Parameter

from torch_frame.nn.conv import TableConv


class TromptConv(TableConv):
    r"""The Trompt cell introduced in the
    `"Trompt: Towards a Better Deep Neural Network for Tabular Data"
    <https://arxiv.org/abs/2305.18446>`_ paper.

    Args:
        channels (int): Input/output channel dimensionality
        num_cols (int): Number of columns
        num_prompts (int): Number of prompt columns.
        num_groups (int): Number of groups in group norm. Group norm is
            applied over :obj:`num_prompts` channels, so :obj:`num_prompts`
            must be divisible by :obj:`num_groups`. (default: :obj:`2`)
    """
    def __init__(
        self,
        channels: int,
        num_cols: int,
        num_prompts: int,
        num_groups: int = 2,
    ):
        super().__init__()
        self.channels = channels
        self.num_cols = num_cols
        self.num_prompts = num_prompts
        # NOTE: registration order below is kept stable so that the
        # state_dict keys and parameters() order do not change.
        # Input-independent embeddings for columns and prompts
        self.embedding_column = Parameter(torch.empty(num_cols, channels))
        self.embedding_prompt = Parameter(torch.empty(num_prompts, channels))
        # Dense layer for transforming the concatenation of the prompt
        # embedding and the incoming prompt: 2 * channels -> channels
        self.lin = Linear(2 * channels, channels)
        # Per-prompt scale used to expand x into num_prompts copies
        # (an outer product of x with this vector; there is no bias term)
        self.weight = Parameter(torch.empty(num_prompts))
        # Group norm over the num_prompts axis of the expanded features
        self.group_norm = GroupNorm(num_groups=num_groups,
                                    num_channels=num_prompts)
        # Layer norms applied to the column / prompt embeddings
        self.layer_norm_e_column = LayerNorm(channels)
        self.layer_norm_e_prompt = LayerNorm(channels)
        self.reset_parameters()

    def reset_parameters(self):
        r"""Re-initializes all learnable parameters of the cell.

        The embeddings and the expansion weight are drawn from a normal
        distribution with standard deviation 0.01; every submodule is reset
        with its own :meth:`reset_parameters`.
        """
        torch.nn.init.normal_(self.embedding_column, std=0.01)
        torch.nn.init.normal_(self.embedding_prompt, std=0.01)
        self.lin.reset_parameters()
        torch.nn.init.normal_(self.weight, std=0.01)
        self.group_norm.reset_parameters()
        self.layer_norm_e_column.reset_parameters()
        self.layer_norm_e_prompt.reset_parameters()

    def forward(self, x: Tensor, x_prompt: Tensor) -> Tensor:
        r"""Transforms :obj:`x` and :obj:`x_prompt` into :obj:`x_prompt`
        for the next layer.

        Args:
            x (torch.Tensor): Feature-based embedding of shape
                :obj:`[batch_size, num_cols, channels]`
            x_prompt (torch.Tensor): Input prompt embeddings of shape
                :obj:`[batch_size, num_prompts, channels]`.

        Returns:
            torch.Tensor: Output prompt embeddings for the next layer. The
                shape is :obj:`[batch_size, num_prompts, channels]`.
        """
        batch_size = len(x)
        assert x.shape == (batch_size, self.num_cols, self.channels)
        assert x_prompt.shape == (batch_size, self.num_prompts,
                                  self.channels)

        # Step 1: Compute stacked_e_prompt.
        # The prompt embedding does not depend on the input, so it is
        # broadcast over the batch with expand() (a view) instead of being
        # copied batch_size times with repeat().
        # [num_prompts, channels]
        embedding_prompt = self.layer_norm_e_prompt(self.embedding_prompt)
        # [batch_size, num_prompts, channels]
        stacked_e_prompt = embedding_prompt.unsqueeze(0).expand(
            batch_size, -1, -1)
        # [batch_size, num_prompts, 2 * channels]
        stacked_e_prompt_cat = torch.cat([stacked_e_prompt, x_prompt],
                                         dim=-1)
        # [batch_size, num_prompts, channels]
        stacked_e_prompt = (stacked_e_prompt + x_prompt +
                            self.lin(stacked_e_prompt_cat))

        # Step 2: Normalize the column embeddings.
        # [num_cols, channels]
        embedding_column = self.layer_norm_e_column(self.embedding_column)

        # Step 3: Use stacked_e_prompt and the column embeddings to compute
        # M_importance. matmul broadcasts the shared [channels, num_cols]
        # operand across the batch, so no per-row copy is needed.
        # [batch_size, num_prompts, channels] @ [channels, num_cols]
        # -> [batch_size, num_prompts, num_cols]
        m_importance = stacked_e_prompt @ embedding_column.t()
        m_importance = F.softmax(m_importance, dim=-1)
        # [batch_size, num_prompts, num_cols, 1]
        m_importance = m_importance.unsqueeze(dim=-1)

        # Step 4: Expand x ([batch_size, num_cols, channels]) into
        # shape [batch_size, num_prompts, num_cols, channels] by scaling it
        # with one learned weight per prompt.
        # [batch_size, 1, num_cols, channels]
        x_unsqueezed = x.unsqueeze(1)
        # [batch_size, 1, num_cols, channels] * [1, num_prompts, 1, 1]
        # -> [batch_size, num_prompts, num_cols, channels]
        z = F.relu(x_unsqueezed * self.weight.view(1, -1, 1, 1))
        # Residual connection. Broadcasting over the prompt axis replaces an
        # explicit repeat() of x, saving one full
        # [batch_size, num_prompts, num_cols, channels] allocation.
        x_expanded = self.group_norm(z) + x_unsqueezed

        # Step 5: Use the expanded x and m_importance to compute x_prompt
        # for the next layer: weight every column by its importance for
        # each prompt and sum over the column axis.
        # [batch_size, num_prompts, num_cols, channels]
        # -> [batch_size, num_prompts, channels]
        return (x_expanded * m_importance).sum(dim=2)