torch_frame.nn.encoder.StypeWiseFeatureEncoder
- class StypeWiseFeatureEncoder(out_channels: int, col_stats: dict[str, dict[torch_frame.data.stats.StatType, Any]], col_names_dict: dict[torch_frame._stype.stype, list[str]], stype_encoder_dict: dict[torch_frame._stype.stype, torch_frame.nn.encoder.stype_encoder.StypeEncoder])
Bases: FeatureEncoder

Feature encoder that transforms the tensor of each stype into column-wise embeddings using the stype encoder registered for it, then concatenates the results along the column dimension.
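The stype-wise encode-then-concatenate idea can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not torch_frame's implementation: the two encoders (`NumericalSketchEncoder`, `CategoricalSketchEncoder`), the helper `stype_wise_encode`, and the string stype keys are hypothetical stand-ins for `StypeEncoder` subclasses and `torch_frame.stype` members. Each encoder maps a `[batch_size, n]` tensor to `[batch_size, n, out_channels]`, and the per-stype outputs are joined along dimension 1.

```python
import torch
from torch import nn


class NumericalSketchEncoder(nn.Module):
    """Hypothetical encoder: a per-column affine map for numerical columns."""

    def __init__(self, num_cols: int, out_channels: int):
        super().__init__()
        # One weight and bias vector per column: [num_cols, out_channels].
        self.weight = nn.Parameter(torch.randn(num_cols, out_channels))
        self.bias = nn.Parameter(torch.zeros(num_cols, out_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # [batch_size, num_cols] -> [batch_size, num_cols, out_channels]
        return x.unsqueeze(-1) * self.weight + self.bias


class CategoricalSketchEncoder(nn.Module):
    """Hypothetical encoder: one embedding table per categorical column."""

    def __init__(self, num_categories: list[int], out_channels: int):
        super().__init__()
        self.embs = nn.ModuleList(
            nn.Embedding(n, out_channels) for n in num_categories)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x holds category indices: [batch_size, num_cols]
        # -> [batch_size, num_cols, out_channels]
        return torch.stack(
            [emb(x[:, i]) for i, emb in enumerate(self.embs)], dim=1)


def stype_wise_encode(
    feat_dict: dict[str, torch.Tensor],
    col_names_dict: dict[str, list[str]],
    encoder_dict: dict[str, nn.Module],
) -> tuple[torch.Tensor, list[str]]:
    xs, all_col_names = [], []
    for stype, feat in feat_dict.items():
        # Encode this stype's columns and record their names in the same order.
        xs.append(encoder_dict[stype](feat))
        all_col_names.extend(col_names_dict[stype])
    # Final concatenation along the column dimension.
    return torch.cat(xs, dim=1), all_col_names


out_channels = 8
feat_dict = {
    "numerical": torch.randn(4, 2),
    "categorical": torch.tensor([[0, 1], [2, 0], [1, 1], [0, 2]]),
}
col_names_dict = {"numerical": ["age", "income"],
                  "categorical": ["city", "job"]}
encoder_dict = {
    "numerical": NumericalSketchEncoder(2, out_channels),
    "categorical": CategoricalSketchEncoder([3, 3], out_channels),
}
x, col_names = stype_wise_encode(feat_dict, col_names_dict, encoder_dict)
print(x.shape)    # torch.Size([4, 4, 8])
print(col_names)  # ['age', 'income', 'city', 'job']
```

Because every encoder emits the same `out_channels`, the per-stype outputs differ only in their column count and can be concatenated directly.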
- Parameters:
out_channels (int) – Output dimensionality of each column embedding.
col_stats (dict[str, dict[torch_frame.data.stats.StatType, Any]]) – A dictionary that maps each column name to its stats. Available as dataset.col_stats.
col_names_dict (dict[torch_frame.stype, list[str]]) – A dictionary that maps each stype to a list of column names. The column names are listed in the same order as the columns of the corresponding tensor in tensor_frame.feat_dict. Available as tensor_frame.col_names_dict.
stype_encoder_dict (dict[torch_frame.stype, torch_frame.nn.encoder.StypeEncoder]) – A dictionary that maps each torch_frame.stype to a torch_frame.nn.encoder.StypeEncoder instance. Only parent stypes are supported as keys.
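The rule that only parent stypes may be used as keys can be illustrated with a small stand-alone model. The `Stype` enum, its `parent` property, and `check_encoder_keys` below are hypothetical: they assume, for illustration, that a child stype such as `text_embedded` is handled by the encoder registered under its parent `embedding`. Check `torch_frame.stype` for the actual members and parent relationships.

```python
from enum import Enum


class Stype(Enum):
    """Hypothetical stand-in for torch_frame.stype."""
    numerical = "numerical"
    categorical = "categorical"
    embedding = "embedding"
    text_embedded = "text_embedded"

    @property
    def parent(self) -> "Stype":
        # Assumed mapping: text_embedded columns are encoded as embeddings.
        if self is Stype.text_embedded:
            return Stype.embedding
        return self


def check_encoder_keys(stype_encoder_dict: dict) -> None:
    """Reject any key that is not its own parent stype."""
    for stype in stype_encoder_dict:
        if stype != stype.parent:
            raise ValueError(
                f"'{stype.value}' is not a parent stype; register the "
                f"encoder under '{stype.parent.value}' instead.")


check_encoder_keys({Stype.numerical: "enc", Stype.embedding: "enc"})  # passes
try:
    check_encoder_keys({Stype.text_embedded: "enc"})
except ValueError as e:
    print(e)
```

Keying encoders by parent stype means one encoder can serve every child stype that shares the same tensor layout.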
- forward(tf: TensorFrame) → tuple[torch.Tensor, list[str]]
Encode a TensorFrame object into a tuple (x, col_names).
- Parameters:
tf (torch_frame.TensorFrame) – Input TensorFrame object.
- Returns:
A tuple of the column-wise output torch.Tensor x of shape [batch_size, num_cols, out_channels] and the list col_names of the column names of x. The length of col_names equals num_cols.
- Return type:
(torch.Tensor, List[str])
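Since the returned column names follow the same order as the column dimension of x, a caller can use them to look up the embedding of a specific column. The sketch below fakes the forward output with a random tensor of the documented shape (the column names and sizes are made up) and shows one way to check the length contract, index by name, and flatten the embeddings for a simple downstream head.

```python
import torch

batch_size, out_channels = 3, 5
# Hypothetical stand-ins for the (x, col_names) tuple returned by forward().
col_names = ["age", "income", "city"]
x = torch.randn(batch_size, len(col_names), out_channels)

# Documented contract: one name per column slice of x.
assert x.shape[1] == len(col_names)

# Map each column name to its position along dim=1.
col_index = {name: i for i, name in enumerate(col_names)}
city_emb = x[:, col_index["city"], :]  # [batch_size, out_channels]
print(city_emb.shape)  # torch.Size([3, 5])

# One option for a simple head: flatten all column embeddings per row.
flat = x.reshape(batch_size, -1)
print(flat.shape)  # torch.Size([3, 15])
```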