torch_frame.transforms.FittableBaseTransform
- class FittableBaseTransform
Bases: BaseTransform
An abstract base class for writing fittable transforms. A fittable transform must be fitted on training data before it can be used to transform data.
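The fit-before-transform contract can be illustrated in plain Python. The following is a minimal, hypothetical sketch, not torch_frame's implementation. A dict that maps column names to lists of floats stands in for a TensorFrame, and the names `FittableTransformSketch`, `ShiftByTrainMin`, `_fit` and `_forward` were invented for this example.

```python
from abc import ABC, abstractmethod
from typing import Any


class FittableTransformSketch(ABC):
    """Hypothetical stand-in for the fit-before-forward contract.

    Not torch_frame's FittableBaseTransform; a dict of column lists
    plays the role of a TensorFrame.
    """

    def __init__(self) -> None:
        self._is_fitted = False

    @property
    def is_fitted(self) -> bool:
        return self._is_fitted

    def fit(self, data: dict[str, list[float]],
            col_stats: dict[str, dict[str, Any]]) -> None:
        # Subclasses learn whatever state they need from the training data.
        self._fit(data, col_stats)
        self._is_fitted = True

    def forward(self, data: dict[str, list[float]]) -> dict[str, list[float]]:
        # Refuse to transform anything until fit() has been called.
        if not self._is_fitted:
            raise ValueError(f"'{type(self).__name__}' is not yet fitted")
        return self._forward(data)

    @abstractmethod
    def _fit(self, data, col_stats) -> None: ...

    @abstractmethod
    def _forward(self, data): ...


class ShiftByTrainMin(FittableTransformSketch):
    """Toy subclass: subtract each column's training minimum."""

    def _fit(self, data, col_stats):
        # The state comes from the training data only.
        self.mins = {col: min(values) for col, values in data.items()}

    def _forward(self, data):
        return {col: [v - self.mins[col] for v in values]
                for col, values in data.items()}
```

In this sketch, calling `forward` on an unfitted `ShiftByTrainMin` raises `ValueError`. After `fit`, the minimums learned from the training data are applied to whatever data is passed in.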
- fit(tf: TensorFrame, col_stats: dict[str, dict[torch_frame.data.stats.StatType, Any]])
Fit the transform on training data.
- Parameters:
tf (TensorFrame) – Input TensorFrame object representing the training data.
col_stats (Dict[str, Dict[StatType, Any]], optional) – The column stats of the input TensorFrame.
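The `col_stats` argument maps each column name to a dictionary of statistics keyed by `StatType`. The rough sketch below shows that shape by computing a mean and a population standard deviation for each numerical column from training data alone. `compute_col_stats` is a hypothetical helper and not part of torch_frame. The string keys `"MEAN"` and `"STD"` are placeholders for `StatType` members, and a dict of column lists stands in for a TensorFrame.

```python
import statistics
from typing import Any


def compute_col_stats(train: dict[str, list[float]]) -> dict[str, dict[str, Any]]:
    """Hypothetical helper: column name -> {stat name -> value}.

    The string keys "MEAN" and "STD" are placeholders for StatType members.
    """
    col_stats: dict[str, dict[str, Any]] = {}
    for col, values in train.items():
        # Statistics are computed from the training rows only.
        col_stats[col] = {
            "MEAN": statistics.fmean(values),
            "STD": statistics.pstdev(values),  # population standard deviation
        }
    return col_stats
```

If the statistics are computed from the training split alone, information from validation or test rows cannot leak into the fitted transform.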
- forward(tf: TensorFrame) → TensorFrame
Process a TensorFrame object into another TensorFrame object.
- Parameters:
tf (TensorFrame) – Input TensorFrame.
- Returns:
The input TensorFrame after the transform has been applied.
- Return type:
TensorFrame
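After `fit`, `forward` applies the stored state to new data, such as a validation or test split, without refitting. The following hypothetical sketch illustrates this. `StandardizeSketch` is not a torch_frame class, a dict of column lists stands in for a TensorFrame, and the string keys `"MEAN"`/`"STD"` stand in for `StatType` members. The sketch keeps the training mean and standard deviation from `col_stats` and uses them to standardize incoming rows. It returns a new mapping and leaves its input unmodified.

```python
from typing import Any, Optional


class StandardizeSketch:
    """Hypothetical fittable transform (not part of torch_frame).

    A dict of column lists stands in for a TensorFrame, and the string keys
    "MEAN"/"STD" stand in for StatType members.
    """

    def __init__(self) -> None:
        self.stats: Optional[dict[str, tuple[float, float]]] = None

    def fit(self, train: dict[str, list[float]],
            col_stats: dict[str, dict[str, Any]]) -> None:
        # Keep only the per-column statistics that forward() needs.
        self.stats = {col: (col_stats[col]["MEAN"], col_stats[col]["STD"])
                      for col in train}

    def forward(self, data: dict[str, list[float]]) -> dict[str, list[float]]:
        if self.stats is None:
            raise ValueError("StandardizeSketch is not yet fitted")
        out: dict[str, list[float]] = {}
        for col, values in data.items():
            mean, std = self.stats[col]
            # Fall back to a scale of 1.0 for constant training columns.
            scale = std if std > 0 else 1.0
            # Build new lists so the caller's input is left unchanged.
            out[col] = [(x - mean) / scale for x in values]
        return out
```

Test data is standardized with the training mean and standard deviation, not with its own statistics, so training and test values end up on the same scale.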