Source code for torch_frame.transforms.base_transform
from __future__ import annotations
import copy
from abc import ABC, abstractmethod
from typing import Any
from torch_frame import TensorFrame
from torch_frame.data.stats import StatType
class BaseTransform(ABC):
    r"""Abstract base class for all :class:`TensorFrame` transforms.

    A transform takes a :class:`TensorFrame` and produces a modified or
    customized :class:`TensorFrame`. Subclasses implement :meth:`forward`.
    Instances are applied by calling them directly, which passes a shallow
    copy of the input to :meth:`forward`.
    """
    def __init__(self):
        # Column stats describing the transform's output; stays ``None``
        # until something assigns it.
        self._transformed_stats: dict[str, dict[StatType, Any]] | None = None

    def __call__(self, tf: TensorFrame) -> TensorFrame:
        # ``copy.copy`` yields a new top-level object, so reassigning its
        # attributes inside ``forward`` leaves the caller's frame untouched.
        # Nested containers are shared unless TensorFrame customizes copying.
        shallow = copy.copy(tf)
        return self.forward(shallow)

    @abstractmethod
    def forward(self, tf: TensorFrame) -> TensorFrame:
        r"""Transform one :class:`TensorFrame` into another.

        Subclasses must override this. The base implementation hands back
        its argument unchanged.

        Args:
            tf (TensorFrame): The :class:`TensorFrame` to transform.

        Returns:
            TensorFrame: The transformed :class:`TensorFrame`.
        """
        return tf

    @property
    def transformed_stats(self) -> dict[str, dict[StatType, Any]]:
        r"""Column stats of the :class:`TensorFrame` after the transform.

        The transform may alter the :class:`TensorFrame`, so these stats
        describe the modified frame rather than the original input.

        Returns:
            transformed_stats (Dict[str, Dict[StatType, Any]]): The
            transformed column stats.

        Raises:
            ValueError: If the transformed stats have not been set yet.
        """
        stats = self._transformed_stats
        if stats is not None:
            return stats
        raise ValueError(
            "Transformed column stats is not computed yet. Please run "
            "necessary functions to compute this first.")

    def __repr__(self) -> str:
        # Render as ``ClassName()``.
        return f'{type(self).__name__}()'