torch_frame.nn.decoder.Decoder
- class Decoder(*args, **kwargs)
Base class for decoders that transform the input column-wise PyTorch tensor into an output tensor, to which the prediction head is then applied.
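Because Decoder is an abstract base class that leaves forward() unimplemented, it cannot be instantiated directly. A subclass becomes instantiable only once it overrides forward(). The sketch below illustrates this. IncompleteDecoder, IdentityDecoder and can_instantiate are hypothetical names used only for illustration. If torch_frame is not installed, a minimal stand-in Decoder is defined instead. That stand-in is an assumption that mirrors only the interface documented here and is not the library's own class.

```python
from abc import ABC, abstractmethod

from torch import Tensor, nn

try:
    from torch_frame.nn.decoder import Decoder
except ImportError:
    # Assumption: stand-in that mirrors the documented interface,
    # used only when torch_frame is not installed.
    class Decoder(nn.Module, ABC):
        @abstractmethod
        def forward(self, x: Tensor, *args, **kwargs):
            raise NotImplementedError


def can_instantiate(cls) -> bool:
    """Hypothetical helper: True if cls() succeeds, False if cls is abstract."""
    try:
        cls()
        return True
    except TypeError:
        return False


class IncompleteDecoder(Decoder):
    """Hypothetical subclass that does not implement forward()."""


class IdentityDecoder(Decoder):
    """Hypothetical subclass: keeps only the first column's embedding."""

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        # [batch_size, num_cols, channels] -> [batch_size, channels]
        return x[:, 0, :]


print(can_instantiate(Decoder))            # False
print(can_instantiate(IncompleteDecoder))  # False
print(can_instantiate(IdentityDecoder))    # True
```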
- abstract forward(x: Tensor, *args: Any, **kwargs: Any) → Any
Decode x of shape [batch_size, num_cols, channels] into an output tensor of shape [batch_size, out_channels].
- Parameters:
x (torch.Tensor) – Input column-wise tensor of shape [batch_size, num_cols, hidden_channels].
args (Any) – Extra positional arguments.
kwargs (Any) – Extra keyword arguments.
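The shape contract of forward() can be illustrated with a small concrete subclass. The sketch below assumes a hypothetical MeanPoolDecoder that averages the column embeddings (reducing [batch_size, num_cols, channels] to [batch_size, channels]) and then applies a linear layer to reach [batch_size, out_channels]. Mean pooling is only one possible design, chosen here for brevity, and is not presented as what any torch_frame decoder does. As in the previous sketch, a stand-in Decoder (an assumption) is defined if torch_frame is unavailable.

```python
from abc import ABC, abstractmethod

import torch
from torch import Tensor, nn

try:
    from torch_frame.nn.decoder import Decoder
except ImportError:
    # Assumption: stand-in that mirrors the documented interface,
    # used only when torch_frame is not installed.
    class Decoder(nn.Module, ABC):
        @abstractmethod
        def forward(self, x: Tensor, *args, **kwargs):
            raise NotImplementedError


class MeanPoolDecoder(Decoder):
    """Hypothetical decoder: mean over columns, then a linear head."""

    def __init__(self, channels: int, out_channels: int) -> None:
        super().__init__()
        self.lin = nn.Linear(channels, out_channels)

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        # Extra *args / **kwargs are accepted, per the interface, but ignored here.
        pooled = x.mean(dim=1)   # [batch_size, num_cols, channels] -> [batch_size, channels]
        return self.lin(pooled)  # [batch_size, channels] -> [batch_size, out_channels]


decoder = MeanPoolDecoder(channels=8, out_channels=3)
x = torch.randn(4, 5, 8)  # batch_size=4, num_cols=5, channels=8
out = decoder(x)
print(out.shape)  # torch.Size([4, 3])
```

Calling the module (decoder(x)) rather than decoder.forward(x) lets nn.Module run any registered hooks before dispatching to forward().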