Source code for torch_frame.transforms.fittable_base_transform
from __future__ import annotations
import copy
from abc import abstractmethod
from typing import Any
import torch
from torch import Tensor
from torch_frame import NAStrategy, TensorFrame
from torch_frame.data.stats import StatType
from torch_frame.transforms import BaseTransform
class FittableBaseTransform(BaseTransform):
    r"""An abstract base class for writing fittable transforms.

    Fittable transforms must be fitted on training data via :meth:`fit`
    before they can transform a :class:`TensorFrame`. Subclasses implement
    :meth:`_fit` (learn state from training data) and :meth:`_forward`
    (apply the learned state).
    """
    def __init__(self) -> None:
        super().__init__()
        self._is_fitted: bool = False

    def __call__(self, tf: TensorFrame) -> TensorFrame:
        r"""Apply :meth:`forward` to a shallow copy of :obj:`tf`."""
        # Shallow-copy so that rebinding attributes of the frame inside
        # ``forward`` does not affect the caller's object. Mutable attributes
        # (and the tensors they hold) are still shared with the caller.
        return self.forward(copy.copy(tf))

    @property
    def is_fitted(self) -> bool:
        r"""Whether the transform is already fitted."""
        return self._is_fitted

    def _replace_nans(self, x: Tensor, na_strategy: NAStrategy) -> Tensor:
        r"""Replace missing values column by column based on NAStrategy.

        A value counts as missing when it is NaN, if
        ``na_strategy.is_numerical_strategy`` is true. Otherwise it counts as
        missing when it is negative, which is assumed to mark missing
        categorical indices. The input tensor is not modified: the
        replacement is done on a clone.

        Args:
            x (Tensor): Input :class:`Tensor`, assumed to be 2-D with shape
                ``[num_rows, num_cols]``, whose missing values will be
                replaced.
            na_strategy (NAStrategy): The :class:`NAStrategy` used to
                replace missing values. :obj:`NAStrategy.MEAN` fills with the
                mean of the column's valid values. :obj:`NAStrategy.ZEROS`
                and :obj:`NAStrategy.MOST_FREQUENT` fill with ``0``.

        Returns:
            Tensor: A copy of :obj:`x` with missing values replaced.

        Raises:
            ValueError: If a non-empty column contains only missing values,
                or if a column contains missing values and
                :obj:`na_strategy` is not one of the strategies listed
                above.
        """
        x = x.clone()
        for col in range(x.size(1)):
            # A view into the clone: the masked assignment below writes
            # straight into ``x``.
            column_data = x[:, col]
            if na_strategy.is_numerical_strategy:
                nan_mask = torch.isnan(column_data)
            else:
                nan_mask = column_data < 0
            # Check ``any()`` first. An empty column (zero rows) has no
            # missing values, but ``all()`` over an empty mask is True and
            # would otherwise be misreported as "only nan values".
            if not nan_mask.any():
                continue
            if nan_mask.all():
                raise ValueError("Column contains only nan values.")
            if na_strategy == NAStrategy.MEAN:
                fill_value = column_data[~nan_mask].mean()
            elif na_strategy in (NAStrategy.ZEROS, NAStrategy.MOST_FREQUENT):
                # NOTE(review): for MOST_FREQUENT, this presumably relies on
                # category index 0 being the most frequent one - confirm
                # against the categorical encoder. A Python scalar is cast to
                # the column's dtype (float or integer) on assignment.
                fill_value = 0
            else:
                raise ValueError(f'{na_strategy} is not supported.')
            column_data[nan_mask] = fill_value
        return x

    def fit(
        self,
        tf: TensorFrame,
        col_stats: dict[str, dict[StatType, Any]],
    ) -> None:
        r"""Fit the transform with train data.

        Delegates to :meth:`_fit`. The transform is marked as fitted only
        after :meth:`_fit` returns without raising.

        Args:
            tf (TensorFrame): Input :class:`TensorFrame` object representing
                the training data.
            col_stats (Dict[str, Dict[StatType, Any]]): The column stats of
                the input :class:`TensorFrame`.
        """
        self._fit(tf, col_stats)
        self._is_fitted = True

    def forward(self, tf: TensorFrame) -> TensorFrame:
        r"""Transform :obj:`tf` using the fitted state.

        Unlike calling the transform directly, this does not shallow-copy
        :obj:`tf` first.

        Args:
            tf (TensorFrame): The :class:`TensorFrame` to transform.

        Returns:
            TensorFrame: The output of :meth:`_forward`, after
            :meth:`TensorFrame.validate` has been called on it.

        Raises:
            ValueError: If the transform has not been fitted yet.
        """
        if not self.is_fitted:
            raise ValueError(
                f"'{self.__class__.__name__}' is not yet fitted. Please run "
                "`fit()` first before attempting to transform the "
                "TensorFrame.")
        transformed_tf = self._forward(tf)
        transformed_tf.validate()
        return transformed_tf

    @abstractmethod
    def _fit(
        self,
        tf: TensorFrame,
        col_stats: dict[str, dict[StatType, Any]],
    ) -> None:
        r"""Learn the transform's state from training data.

        Called by :meth:`fit`; must be implemented by subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def _forward(self, tf: TensorFrame) -> TensorFrame:
        r"""Apply the learned state to :obj:`tf`.

        Called by :meth:`forward` once the transform is fitted; must be
        implemented by subclasses.
        """
        raise NotImplementedError

    def state_dict(self) -> dict[str, Any]:
        r"""Return the transform's state as a dictionary.

        The result is a shallow copy of the instance attributes. Adding,
        removing or rebinding keys in it does not affect the transform, so a
        saved state can later be restored with :meth:`load_state_dict`. The
        values themselves (e.g. tensors) are shared, not copied.
        """
        return dict(self.__dict__)

    def load_state_dict(
        self,
        state_dict: dict[str, Any],
    ) -> FittableBaseTransform:
        r"""Load a state produced by :meth:`state_dict` into this transform.

        Every key in :obj:`state_dict` overwrites the matching instance
        attribute. Attributes that are absent from it are left untouched.

        Args:
            state_dict (Dict[str, Any]): Attribute name to value mapping.

        Returns:
            FittableBaseTransform: :obj:`self`, to allow chaining.
        """
        self.__dict__.update(state_dict)
        return self