Source code for torch_frame.data.stats

from __future__ import annotations

import copy
from enum import Enum
from typing import Any

import numpy as np
import pandas as pd
import pandas.api.types as ptypes
import torch

import torch_frame
from torch_frame.data.mapper import (
    MultiCategoricalTensorMapper,
    TimestampTensorMapper,
)
from torch_frame.typing import Series


def _flatten_numeric(values: np.ndarray) -> np.ndarray:
    r"""Flatten the values of a numeric column into a single 1D array.

    Args:
        values: Either a 1D numeric array, or an array whose elements are
            themselves sequences of numbers (nested rows).

    Returns:
        The input itself when it is already a 1D non-object array, an empty
        ``float64`` array when there is nothing to flatten, and otherwise
        the result of :func:`numpy.hstack` over the rows.
    """
    # Fast path: already a 1D numeric array (e.g., the `numerical` stype).
    # Avoids the expensive `np.hstack` call used for nested sequences.
    if values.dtype != object and values.ndim == 1:
        return values
    # `np.hstack` raises a `ValueError` when given no arrays at all, so an
    # empty column is answered directly. A float dtype keeps `np.isfinite`
    # usable on the result.
    if values.size == 0:
        return np.empty(0, dtype=np.float64)
    return np.hstack(values)


class StatType(Enum):
    r"""The different types for column statistics.

    Attributes:
        MEAN: The average value of a numerical column.
        STD: The standard deviation of a numerical column.
        QUANTILES: The minimum, first quartile, median, third quartile,
            and the maximum of a numerical column.
        COUNT: The count of each category in a categorical column.
        MULTI_COUNT: The count of each category in a multi-categorical
            column.
        YEAR_RANGE: The range of years in a timestamp column.
        OLDEST_TIME: The first entry of a timestamp column, converted with
            :class:`TimestampTensorMapper`. This is the oldest time only
            when the column is sorted in ascending order.
        NEWEST_TIME: The last entry of a timestamp column, converted with
            :class:`TimestampTensorMapper`. This is the newest time only
            when the column is sorted in ascending order.
        MEDIAN_TIME: The middle entry (position ``len // 2``) of a
            timestamp column, converted with
            :class:`TimestampTensorMapper`.
        EMB_DIM: The length of the first entry of an embedding column.
    """
    # Numerical:
    MEAN = "MEAN"
    STD = "STD"
    QUANTILES = "QUANTILES"

    # categorical:
    COUNT = "COUNT"

    # multicategorical:
    MULTI_COUNT = "MULTI_COUNT"

    # timestamp
    YEAR_RANGE = "YEAR_RANGE"
    OLDEST_TIME = "OLDEST_TIME"
    NEWEST_TIME = "NEWEST_TIME"
    MEDIAN_TIME = "MEDIAN_TIME"

    # text_embedded (Also, embedding)
    # Note: For text_embedded, this stats is computed in
    # dataset._update_col_stats, not here.
    EMB_DIM = "EMB_DIM"

    @staticmethod
    def stats_for_stype(stype: torch_frame.stype) -> list[StatType]:
        r"""Return the statistics that are computed for a semantic type.

        Args:
            stype: The semantic type of the column.

        Returns:
            The list of :class:`StatType` members registered for
            :obj:`stype`, or an empty list when none are registered.
        """
        stats_type = {
            torch_frame.numerical: [
                StatType.MEAN,
                StatType.STD,
                StatType.QUANTILES,
            ],
            torch_frame.categorical: [StatType.COUNT],
            torch_frame.multicategorical: [StatType.MULTI_COUNT],
            torch_frame.sequence_numerical: [
                StatType.MEAN,
                StatType.STD,
                StatType.QUANTILES,
            ],
            torch_frame.timestamp: [
                StatType.YEAR_RANGE,
                StatType.NEWEST_TIME,
                StatType.OLDEST_TIME,
                StatType.MEDIAN_TIME,
            ],
            torch_frame.embedding: [
                StatType.EMB_DIM,
            ]
        }
        return stats_type.get(stype, [])

    def compute(
        self,
        ser: Series,
        sep: str | None = None,
    ) -> Any:
        r"""Compute this statistic over a column.

        Args:
            ser: The column values. The timestamp statistics read entries
                by position, so they expect a datetime series that is
                already sorted in ascending order.
            sep: The separator handed to
                :meth:`MultiCategoricalTensorMapper.split_by_sep`. Only
                used by :obj:`MULTI_COUNT`.

        Returns:
            A float for :obj:`MEAN` and :obj:`STD`, a list of five floats
            for :obj:`QUANTILES`, a ``(categories, counts)`` pair of lists
            for :obj:`COUNT` and :obj:`MULTI_COUNT`, a ``[min, max]`` list
            for :obj:`YEAR_RANGE`, the output of
            :meth:`TimestampTensorMapper.to_tensor` with its leading
            dimension squeezed for the ``*_TIME`` statistics, and an
            integer for :obj:`EMB_DIM`.
        """
        if self in (StatType.MEAN, StatType.STD, StatType.QUANTILES):
            # All numerical statistics ignore NaN and +/-inf entries.
            flattened = _flatten_numeric(ser.values)
            finite = flattened[np.isfinite(flattened)]
            if finite.size == 0:
                # NOTE: We may just error out here if everything is NaN
                if self is StatType.QUANTILES:
                    return [np.nan, np.nan, np.nan, np.nan, np.nan]
                return np.nan
            if self is StatType.MEAN:
                return np.mean(finite).item()
            if self is StatType.STD:
                return np.std(finite).item()
            return np.quantile(
                finite,
                q=[0, 0.25, 0.5, 0.75, 1],
            ).tolist()

        elif self is StatType.COUNT:
            count = ser.value_counts(ascending=False)
            return count.index.tolist(), count.values.tolist()

        elif self is StatType.MULTI_COUNT:
            # Split every row into its categories, then count each
            # category across all rows.
            ser = ser.apply(lambda row: MultiCategoricalTensorMapper.
                            split_by_sep(row, sep))
            ser = ser.explode().dropna()
            count = ser.value_counts(ascending=False)
            return count.index.tolist(), count.values.tolist()

        elif self is StatType.YEAR_RANGE:
            year_range = ser.dt.year.values
            return [min(year_range), max(year_range)]

        elif self is StatType.NEWEST_TIME:
            return TimestampTensorMapper.to_tensor(pd.Series(
                ser.iloc[-1])).squeeze(0)

        elif self is StatType.OLDEST_TIME:
            return TimestampTensorMapper.to_tensor(pd.Series(
                ser.iloc[0])).squeeze(0)

        elif self is StatType.MEDIAN_TIME:
            return TimestampTensorMapper.to_tensor(
                pd.Series(ser.iloc[len(ser) // 2])).squeeze(0)

        elif self is StatType.EMB_DIM:
            # Positional access: `ser[0]` is a label lookup and fails when
            # the index does not contain 0 (e.g. after leading rows were
            # dropped).
            return len(ser.iloc[0])
# Placeholder statistics reported for a column that holds no non-null value.
_default_values = {
    StatType.MEAN: np.nan,
    StatType.STD: np.nan,
    StatType.QUANTILES: [np.nan, np.nan, np.nan, np.nan, np.nan],
    StatType.COUNT: ([], []),
    StatType.MULTI_COUNT: ([], []),
    StatType.YEAR_RANGE: [-1, -1],
    StatType.NEWEST_TIME: torch.tensor([-1, -1, -1, -1, -1, -1, -1]),
    StatType.OLDEST_TIME: torch.tensor([-1, -1, -1, -1, -1, -1, -1]),
    StatType.MEDIAN_TIME: torch.tensor([-1, -1, -1, -1, -1, -1, -1]),
    StatType.EMB_DIM: -1,
}


def compute_col_stats(
    ser: Series,
    stype: torch_frame.stype,
    sep: str | None = None,
    time_format: str | None = None,
) -> dict[StatType, Any]:
    r"""Compute every statistic registered for a column's semantic type.

    Args:
        ser: The raw column.
        stype: The semantic type of the column; selects the statistics via
            :meth:`StatType.stats_for_stype`.
        sep: The separator forwarded to :meth:`StatType.compute`.
        time_format: The format forwarded to :func:`pandas.to_datetime`
            for timestamp columns.

    Returns:
        A dictionary mapping each applicable :class:`StatType` to its
        value. When the column has no non-null entry, the values are
        copies of the entries in :obj:`_default_values`.

    Raises:
        TypeError: If :obj:`stype` is numerical but the series does not
            have a numeric dtype.
    """
    if stype == torch_frame.numerical:
        # Infinite entries are treated like missing values.
        ser = ser.mask(ser.isin([np.inf, -np.inf]), np.nan)
        if not ptypes.is_numeric_dtype(ser):
            raise TypeError("Numerical series contains invalid entries. "
                            "Please make sure your numerical series "
                            "contains only numerical values or nans.")

    stat_types = StatType.stats_for_stype(stype)

    if ser.isnull().all():
        # NOTE: We may just error out here if everything is NaN
        # Hand out copies: the defaults hold lists and tensors, and
        # returning them by reference would let an in-place edit by the
        # caller change the defaults of every later column.
        return {
            stat_type: copy.deepcopy(_default_values[stat_type])
            for stat_type in stat_types
        }

    if stype == torch_frame.timestamp:
        # The timestamp statistics read entries by position, so the
        # series is sorted before they are computed.
        ser = pd.to_datetime(ser, format=time_format)
        ser = ser.sort_values()

    # Drop missing entries once instead of once per statistic.
    non_null = ser.dropna()
    return {
        stat_type: stat_type.compute(non_null, sep)
        for stat_type in stat_types
    }