torch_frame.nn.encoder.StypeWiseFeatureEncoder

class StypeWiseFeatureEncoder(out_channels: int, col_stats: dict[str, dict[torch_frame.data.stats.StatType, Any]], col_names_dict: dict[torch_frame._stype.stype, list[str]], stype_encoder_dict: dict[torch_frame._stype.stype, torch_frame.nn.encoder.stype_encoder.StypeEncoder])

Bases: FeatureEncoder

Feature encoder that transforms each stype tensor into embeddings and performs the final concatenation.
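The idea of encoding each stype separately and then concatenating can be illustrated with a small plain-PyTorch sketch. This is not the library's implementation: `ToyStypeWiseEncoder`, its string stype keys, and its `num_categories` argument are hypothetical names introduced here only to show the shape bookkeeping.

```python
import torch
from torch import nn


class ToyStypeWiseEncoder(nn.Module):
    """Hypothetical stand-in: one sub-encoder per stype, outputs concatenated."""

    def __init__(self, out_channels, col_names_dict, num_categories):
        super().__init__()
        self.col_names_dict = col_names_dict
        num_numerical = len(col_names_dict["numerical"])
        # One weight/bias vector per numerical column.
        self.weight = nn.Parameter(torch.randn(num_numerical, out_channels))
        self.bias = nn.Parameter(torch.zeros(num_numerical, out_channels))
        # One embedding table per categorical column.
        self.embeddings = nn.ModuleList(
            nn.Embedding(n, out_channels) for n in num_categories
        )

    def forward(self, feat_dict):
        xs, all_col_names = [], []
        for stype_name, feat in feat_dict.items():
            if stype_name == "numerical":
                # [batch, cols] -> [batch, cols, out_channels]
                x = feat.unsqueeze(-1) * self.weight + self.bias
            else:
                # Embed each categorical column, then stack along the column axis.
                x = torch.stack(
                    [emb(feat[:, i]) for i, emb in enumerate(self.embeddings)],
                    dim=1,
                )
            xs.append(x)
            all_col_names.extend(self.col_names_dict[stype_name])
        # Final concatenation across all stypes along the column axis.
        return torch.cat(xs, dim=1), all_col_names


col_names_dict = {"numerical": ["age", "income"], "categorical": ["city"]}
encoder = ToyStypeWiseEncoder(
    out_channels=8, col_names_dict=col_names_dict, num_categories=[3]
)
feat_dict = {
    "numerical": torch.randn(4, 2),
    "categorical": torch.randint(0, 3, (4, 1)),
}
x, col_names = encoder(feat_dict)
print(x.shape, col_names)  # torch.Size([4, 3, 8]) ['age', 'income', 'city']
```

Every stype contributes a block of shape [batch_size, num_cols_of_stype, out_channels], so concatenating along dimension 1 yields one embedding per column.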

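Parameters:

out_channels (int) – Output dimensionality of the embeddings.

col_stats (dict[str, dict[StatType, Any]]) – A dictionary that maps each column name to its statistics.

col_names_dict (dict[stype, list[str]]) – A dictionary that maps each stype to a list of column names.

stype_encoder_dict (dict[stype, StypeEncoder]) – A dictionary that maps each stype to the StypeEncoder used for its columns.
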
forward(tf: TensorFrame) → tuple[torch.Tensor, list[str]]

Encode TensorFrame object into a tuple (x, col_names).

Parameters:

tf (torch_frame.TensorFrame) – Input TensorFrame object.

Returns:

A tuple of an output column-wise torch.Tensor of shape [batch_size, num_cols, out_channels] and a list of column names of x. The list has length num_cols.

Return type:

(torch.Tensor, List[str])