torch_frame.nn.encoder.StypeWiseFeatureEncoder
- class StypeWiseFeatureEncoder(out_channels: int, col_stats: dict[str, dict[torch_frame.data.stats.StatType, Any]], col_names_dict: dict[torch_frame._stype.stype, list[str]], stype_encoder_dict: dict[torch_frame._stype.stype, torch_frame.nn.encoder.stype_encoder.StypeEncoder])
Bases: FeatureEncoder
Feature encoder that transforms the tensor of each stype into embeddings with its own stype encoder, then concatenates the per-stype embeddings along the column dimension into a single tensor.
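The encode-then-concatenate idea can be illustrated in plain PyTorch. This is a hypothetical sketch, not the library's implementation: `ToyNumericalEncoder`, `ToyCategoricalEncoder`, and `stype_wise_encode` are made-up names, and stypes are keyed by plain strings rather than `torch_frame.stype` values.

```python
import torch
from torch import Tensor, nn


class ToyNumericalEncoder(nn.Module):
    """Hypothetical per-column linear encoder: [B, n_cols] -> [B, n_cols, C]."""

    def __init__(self, num_cols: int, out_channels: int) -> None:
        super().__init__()
        # One weight vector and one bias vector per numerical column.
        self.weight = nn.Parameter(torch.randn(num_cols, out_channels))
        self.bias = nn.Parameter(torch.zeros(num_cols, out_channels))

    def forward(self, x: Tensor) -> Tensor:
        # [B, n_cols, 1] * [n_cols, C] broadcasts to [B, n_cols, C].
        return x.unsqueeze(-1) * self.weight + self.bias


class ToyCategoricalEncoder(nn.Module):
    """Hypothetical embedding encoder: integer codes [B, n_cols] -> [B, n_cols, C]."""

    def __init__(self, num_categories: int, out_channels: int) -> None:
        super().__init__()
        self.emb = nn.Embedding(num_categories, out_channels)

    def forward(self, x: Tensor) -> Tensor:
        return self.emb(x)


def stype_wise_encode(feat_dict, col_names_dict, encoder_dict):
    """Encode each stype separately, then concatenate along the column axis."""
    xs, all_col_names = [], []
    for stype_name, feat in feat_dict.items():
        xs.append(encoder_dict[stype_name](feat))  # [B, n_cols_of_stype, C]
        all_col_names.extend(col_names_dict[stype_name])
    # Concatenate on dim=1 so column names stay aligned with x[:, i, :].
    return torch.cat(xs, dim=1), all_col_names


batch_size, out_channels = 4, 8
feat_dict = {
    "categorical": torch.randint(0, 5, (batch_size, 2)),
    "numerical": torch.randn(batch_size, 3),
}
col_names_dict = {"categorical": ["a", "b"], "numerical": ["c", "d", "e"]}
encoder_dict = {
    "categorical": ToyCategoricalEncoder(num_categories=5, out_channels=out_channels),
    "numerical": ToyNumericalEncoder(num_cols=3, out_channels=out_channels),
}

x, col_names = stype_wise_encode(feat_dict, col_names_dict, encoder_dict)
print(x.shape)    # torch.Size([4, 5, 8])
print(col_names)  # ['a', 'b', 'c', 'd', 'e']
```

Concatenating along dimension 1 keeps every column's embedding at the same width, which is what lets downstream column-wise layers treat categorical and numerical columns uniformly.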
- Parameters:
out_channels (int) – Output dimensionality.
col_stats (dict[str, dict[torch_frame.data.stats.StatType, Any]]) – A dictionary that maps each column name to its stats. Available as dataset.col_stats.
col_names_dict (dict[torch_frame.stype, list[str]]) – A dictionary that maps each stype to a list of column names. The column names are ordered to match the columns of the corresponding tensor in tensor_frame.feat_dict. Available as tensor_frame.col_names_dict.
stype_encoder_dict (dict[torch_frame.stype, torch_frame.nn.encoder.StypeEncoder]) – A dictionary that maps each torch_frame.stype to a torch_frame.nn.encoder.StypeEncoder instance. Only parent stypes are supported as keys (e.g., stype.embedding rather than stype.text_embedded).
- forward(tf: TensorFrame) → tuple[torch.Tensor, list[str]]
Encode a TensorFrame object into a tuple (x, col_names).
- Parameters:
tf (torch_frame.TensorFrame) – Input TensorFrame object.
- Returns:
- A tuple of the column-wise output torch.Tensor x of shape [batch_size, num_cols, out_channels] and the list of column names corresponding to dimension 1 of x. The list's length equals num_cols.
- Return type:
(torch.Tensor, List[str])