torch_frame.nn.conv.TableConv

class TableConv(*args, **kwargs)

Bases: Module, ABC

Base class for table convolutions, which transform an input column-wise PyTorch tensor.

abstract forward(x: Tensor, *args: Any, **kwargs: Any) → Any

Processes a column-wise 3-dimensional tensor into another column-wise 3-dimensional tensor.

Parameters:
  • x (torch.Tensor) – Input column-wise tensor of shape [batch_size, num_cols, hidden_channels].

  • args (Any) – Extra positional arguments.

  • kwargs (Any) – Extra keyword arguments.

reset_parameters() → None

Resets all learnable parameters of the module.