torch_frame.transforms.FittableBaseTransform

class FittableBaseTransform

Bases: BaseTransform

An abstract base class for writing fittable transforms. A fittable transform must be fitted on training data before it can be used to transform data.

property is_fitted: bool

Whether the transform has already been fitted.

fit(tf: TensorFrame, col_stats: dict[str, dict[torch_frame.data.stats.StatType, Any]])

Fit the transform on the training data.

Parameters:
  • tf (TensorFrame) – Input TensorFrame object representing the training data.

  • col_stats (Dict[str, Dict[StatType, Any]]) – The column statistics of the input TensorFrame.
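The col_stats argument is a nested mapping: each column name maps to a dictionary keyed by StatType. The sketch below shows only that data layout, using the standard library. StatTypeSketch is a hypothetical stand-in for torch_frame.data.stats.StatType (just two members are mimicked), a dict of float lists stands in for the columns of a TensorFrame, and compute_col_stats is an illustrative helper, not torch_frame's own statistics code.

```python
import statistics
from enum import Enum
from typing import Any


class StatTypeSketch(Enum):
    # Hypothetical stand-in for torch_frame.data.stats.StatType.
    MEAN = "MEAN"
    STD = "STD"


def compute_col_stats(
    columns: dict[str, list[float]],
) -> dict[str, dict[StatTypeSketch, Any]]:
    """Return a col_stats-shaped mapping: column name -> {stat type -> value}."""
    return {
        name: {
            StatTypeSketch.MEAN: statistics.fmean(values),
            StatTypeSketch.STD: statistics.pstdev(values),  # population std
        }
        for name, values in columns.items()
    }


train = {"age": [20.0, 30.0, 40.0], "income": [1.0, 1.0, 4.0]}
col_stats = compute_col_stats(train)

# A fit step can read precomputed values instead of rescanning the data.
age_mean = col_stats["age"][StatTypeSketch.MEAN]
print(age_mean)  # 30.0
```

Passing statistics in this form lets a transform's fit step look values up by column and stat type rather than recomputing them from the training rows.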

forward(tf: TensorFrame) → TensorFrame

Process a TensorFrame object into another TensorFrame object.

Parameters:

tf (TensorFrame) – Input TensorFrame.

Returns:

The input TensorFrame after the transform has been applied.

Return type:

TensorFrame
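Together, is_fitted, fit, and forward describe a fit-then-transform contract: a transform learns its state from training data once, and only then processes frames. The following is a minimal plain-Python sketch of that contract that assumes nothing about torch_frame's internals. FittableTransformSketch, CenterSketch, and the _fit/_forward hook names are hypothetical, a dict of float lists stands in for a TensorFrame, and the col_stats argument is omitted for brevity.

```python
from abc import ABC, abstractmethod

Columns = dict[str, list[float]]  # hypothetical stand-in for a TensorFrame


class FittableTransformSketch(ABC):
    """Hypothetical base class illustrating fit-before-forward (not torch_frame code)."""

    def __init__(self) -> None:
        self._is_fitted = False

    @property
    def is_fitted(self) -> bool:
        # Becomes True only after fit() completes.
        return self._is_fitted

    def fit(self, tf: Columns) -> None:
        # Learn state from training data, then mark the transform as fitted.
        self._fit(tf)
        self._is_fitted = True

    def forward(self, tf: Columns) -> Columns:
        # Refuse to transform before fitting.
        if not self._is_fitted:
            raise ValueError(f"{type(self).__name__} is not fitted; call fit() first")
        return self._forward(tf)

    @abstractmethod
    def _fit(self, tf: Columns) -> None: ...

    @abstractmethod
    def _forward(self, tf: Columns) -> Columns: ...


class CenterSketch(FittableTransformSketch):
    """Subtracts each column's training-set mean."""

    def _fit(self, tf: Columns) -> None:
        self.means = {name: sum(vals) / len(vals) for name, vals in tf.items()}

    def _forward(self, tf: Columns) -> Columns:
        return {name: [v - self.means[name] for v in vals] for name, vals in tf.items()}


transform = CenterSketch()
print(transform.is_fitted)  # False
transform.fit({"age": [20.0, 30.0, 40.0]})
print(transform.is_fitted)  # True
print(transform.forward({"age": [25.0, 35.0]}))  # {'age': [-5.0, 5.0]}
```

Storing the fitted state (here, the training means) on the transform means held-out data is processed with the training statistics rather than with statistics recomputed from itself.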