torch_frame.nn.encoder.LinearEmbeddingEncoder

class LinearEmbeddingEncoder(out_channels: Optional[int] = None, stats_list: Optional[list[dict[torch_frame.data.stats.StatType, Any]]] = None, stype: Optional[stype] = None, post_module: Optional[Module] = None, na_strategy: Optional[NAStrategy] = None)

Bases: StypeEncoder

Linear-function-based encoder for pre-computed embedding features. It applies a linear layer torch.nn.Linear(emb_dim, out_channels) to each embedding feature, where emb_dim is that column's embedding dimension, and concatenates the outputs along the column dimension.
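The per-column computation described above can be sketched in plain PyTorch. This is an illustrative sketch, not the torch_frame implementation: the class name `PerColumnLinear` and its `emb_dims` argument (one embedding width per column) are hypothetical, and the input is simplified to a Python list of per-column tensors instead of a `MultiEmbeddingTensor`.

```python
import torch
from torch import nn


class PerColumnLinear(nn.Module):
    """Hypothetical sketch: one nn.Linear(emb_dim, out_channels) per column."""

    def __init__(self, emb_dims: list[int], out_channels: int) -> None:
        super().__init__()
        # Columns may have different embedding widths, so each gets its own layer.
        self.lins = nn.ModuleList(
            [nn.Linear(emb_dim, out_channels) for emb_dim in emb_dims]
        )

    def forward(self, cols: list[torch.Tensor]) -> torch.Tensor:
        # cols[i] has shape [batch_size, emb_dims[i]].
        outs = [lin(col) for lin, col in zip(self.lins, cols)]
        # Stack per-column outputs into [batch_size, num_cols, out_channels].
        return torch.stack(outs, dim=1)


batch_size, out_channels = 4, 8
emb_dims = [16, 32, 5]  # hypothetical embedding widths for three columns
cols = [torch.randn(batch_size, d) for d in emb_dims]
encoder = PerColumnLinear(emb_dims, out_channels)
x = encoder(cols)
print(x.shape)  # torch.Size([4, 3, 8])
```

Using a separate layer per column is what lets columns with different embedding widths share a common output width out_channels.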

reset_parameters() → None

Initialize the parameters of post_module.

encode_forward(feat: MultiEmbeddingTensor, col_names: Optional[list[str]] = None) → Tensor

The main forward function. Maps the input feat from the TensorFrame (logically of shape [batch_size, num_cols], with one embedding vector per cell) into an output x of shape [batch_size, num_cols, out_channels].
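To make the shape mapping concrete, the sketch below assumes all columns' embeddings are packed side by side in a single [batch_size, sum(emb_dims)] tensor, with an offset list marking where each column starts. This is a simplified, illustrative stand-in for a ragged multi-embedding input; the names `values`, `offset`, and `emb_dims` are assumptions for this example, not the documented torch_frame API. Each column slice passes through its own linear layer, and the results are stacked.

```python
import torch
from torch import nn

torch.manual_seed(0)

batch_size, out_channels = 2, 5
emb_dims = [4, 2, 6]  # hypothetical per-column embedding widths

# Start offset of each column in the packed tensor, plus the total width.
offset = [0]
for d in emb_dims:
    offset.append(offset[-1] + d)  # -> [0, 4, 6, 12]

# All column embeddings packed side by side: [batch_size, sum(emb_dims)].
values = torch.randn(batch_size, offset[-1])

lins = nn.ModuleList([nn.Linear(d, out_channels) for d in emb_dims])

outs = []
for i, lin in enumerate(lins):
    # Slice out column i: [batch_size, emb_dims[i]].
    col = values[:, offset[i]:offset[i + 1]]
    # Project it to [batch_size, out_channels].
    outs.append(lin(col))

# [batch_size, num_cols, out_channels]
x = torch.stack(outs, dim=1)
print(x.shape)  # torch.Size([2, 3, 5])
```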