torch_frame.nn.decoder.Decoder

class Decoder(*args, **kwargs)

Bases: Module, ABC

Base class for decoders that transform the input column-wise PyTorch tensor into an output tensor on which the prediction head is applied.

abstract forward(x: Tensor, *args: Any, **kwargs: Any) → Any

Decode x of shape [batch_size, num_cols, channels] into an output tensor of shape [batch_size, out_channels].

Parameters:
  • x (torch.Tensor) – Input column-wise tensor of shape [batch_size, num_cols, channels].

  • args (Any) – Extra arguments.

  • kwargs (Any) – Extra keyword arguments.

reset_parameters() → None

Resets all learnable parameters of the module.
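
Example:

The interface above can be illustrated with a minimal sketch of a concrete subclass. MeanPoolDecoder and its in_channels/out_channels arguments are hypothetical names chosen for this example and are not part of torch_frame; the sketch assumes that averaging over the column dimension followed by a linear layer is an acceptable decoding strategy. If torch_frame is not installed, it falls back to a stand-in base class that only mirrors the documented signatures.

```python
import torch
from torch import Tensor
from torch.nn import Linear

try:
    from torch_frame.nn.decoder import Decoder
except ImportError:
    # Assumption: stand-in that mirrors only the documented interface,
    # used so the sketch still runs without torch_frame installed.
    from abc import ABC, abstractmethod

    class Decoder(torch.nn.Module, ABC):
        @abstractmethod
        def forward(self, x: Tensor, *args, **kwargs):
            raise NotImplementedError

        def reset_parameters(self) -> None:
            pass


class MeanPoolDecoder(Decoder):
    """Hypothetical decoder: mean over columns, then a linear layer."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.lin = Linear(in_channels, out_channels)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Re-initialize the only learnable submodule.
        self.lin.reset_parameters()

    def forward(self, x: Tensor) -> Tensor:
        # x: [batch_size, num_cols, channels]
        pooled = x.mean(dim=1)   # [batch_size, channels]
        return self.lin(pooled)  # [batch_size, out_channels]


decoder = MeanPoolDecoder(in_channels=16, out_channels=4)
x = torch.randn(32, 10, 16)  # 32 rows, 10 columns, 16 channels each
out = decoder(x)
print(out.shape)
```

Because forward is abstract, a subclass has to override it before it can be instantiated; overriding reset_parameters is what lets a caller re-initialize the decoder's weights in place.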