Source code for torch_frame.nn.decoder.excelformer_decoder

import torch
from torch import Tensor
from torch.nn import Linear, PReLU

from torch_frame.nn.decoder import Decoder


[docs]class ExcelFormerDecoder(Decoder): r"""The ExcelFormer decoder introduced in the `"ExcelFormer: A Neural Network Surpassing GBDTs on Tabular Data" <https://arxiv.org/abs/2301.02819>`_ paper. Args: in_channels (int): Input channel dimensionality out_channels (int): Output channel dimensionality num_cols (int): Number of columns. """ def __init__( self, in_channels: int, out_channels: int, num_cols: int, ) -> None: super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.lin_f = Linear(num_cols, self.out_channels) self.activation = PReLU() self.lin_d = Linear(self.in_channels, 1) self.reset_parameters()
[docs] def reset_parameters(self) -> None: self.lin_f.reset_parameters() self.lin_d.reset_parameters() with torch.no_grad(): self.activation.weight.fill_(0.25)
[docs] def forward(self, x: Tensor) -> Tensor: r"""Transforming :obj:`x` into output predictions. Args: x (Tensor): Input column-wise tensor of shape [batch_size, num_cols, in_channels] Returns: Tensor: [batch_size, out_channels]. """ x = x.transpose(1, 2) x = self.lin_f(x) x = self.activation(x) x = self.lin_d(x.transpose(1, 2)).squeeze(2) return x