from __future__ import annotations
from collections.abc import Mapping
from enum import Enum
from typing import TypeAlias
import pandas as pd
import torch
from torch import Tensor
from torch_frame.data.multi_embedding_tensor import MultiEmbeddingTensor
from torch_frame.data.multi_nested_tensor import MultiNestedTensor
def _major_minor(version: str) -> tuple[int, int]:
    r"""Parse the ``(major, minor)`` release numbers of a version string.

    Only the leading numeric release segment is read, so suffixes such as
    ``'+cu121'``, ``'rc0'`` or ``'.dev20240801'`` are ignored. A missing
    minor component is treated as ``0``.

    Args:
        version (str): A version string such as ``torch.__version__``.

    Returns:
        tuple[int, int]: The ``(major, minor)`` pair.

    Raises:
        ValueError: If :obj:`version` does not start with a number.
    """
    import re  # Local import: this helper only runs at module import time.
    match = re.match(r'(\d+)(?:\.(\d+))?', version)
    if match is None:
        raise ValueError(f"Cannot parse version string '{version}'")
    major, minor = match.groups(default='0')
    return int(major), int(minor)


# Compare (major, minor) tuples as a whole: checking the minor number on its
# own would report e.g. torch 3.0 as older than 2.4 (since 0 < 4).
_TORCH_VERSION = _major_minor(torch.__version__)
WITH_PT20 = _TORCH_VERSION >= (2, 0)
WITH_PT24 = _TORCH_VERSION >= (2, 4)
WITH_PD3 = _major_minor(pd.__version__) >= (3, 0)
class Metric(Enum):
    r"""Evaluation metric used to score predictions.

    Attributes:
        ACCURACY: Accuracy (a classification metric, ``'accuracy'``).
        ROCAUC: Area under the ROC curve (a binary classification metric,
            ``'rocauc'``).
        RMSE: Root mean squared error (a regression metric, ``'rmse'``).
        MAE: Mean absolute error (a regression metric, ``'mae'``).
        R2: Coefficient of determination (a regression metric, ``'r2'``).
    """
    # Member order is part of the public behavior (Enum iteration order).
    ACCURACY = 'accuracy'
    ROCAUC = 'rocauc'
    RMSE = 'rmse'
    MAE = 'mae'
    R2 = 'r2'

    def supports_task_type(self, task_type: TaskType) -> bool:
        r"""Check whether this metric can evaluate the given task.

        Args:
            task_type (TaskType): The task to check against.

        Returns:
            bool: :obj:`True` if this metric appears in
            :obj:`task_type.supported_metrics`.
        """
        allowed = task_type.supported_metrics
        return self in allowed
class TaskType(Enum):
    r"""The kind of prediction task.

    Attributes:
        REGRESSION: Regression task.
        MULTICLASS_CLASSIFICATION: Multi-class classification task.
        BINARY_CLASSIFICATION: Binary classification task.
        MULTILABEL_CLASSIFICATION: Multi-label classification task. It is not
            counted by :obj:`is_classification` and has no
            :obj:`supported_metrics`.
    """
    # Member order is part of the public behavior (Enum iteration order).
    REGRESSION = 'regression'
    MULTICLASS_CLASSIFICATION = 'multiclass_classification'
    BINARY_CLASSIFICATION = 'binary_classification'
    MULTILABEL_CLASSIFICATION = 'multilabel_classification'

    @property
    def is_classification(self) -> bool:
        r"""Whether the task is binary or multi-class classification.

        :obj:`MULTILABEL_CLASSIFICATION` is not included.
        """
        return (self is TaskType.BINARY_CLASSIFICATION
                or self is TaskType.MULTICLASS_CLASSIFICATION)

    @property
    def is_regression(self) -> bool:
        r"""Whether the task is regression."""
        return self is TaskType.REGRESSION

    @property
    def supported_metrics(self) -> list[Metric]:
        r"""The metrics that can evaluate this task.

        Returns:
            list[Metric]: A fresh list on every call; empty for
            :obj:`MULTILABEL_CLASSIFICATION`.
        """
        # Built per call so that callers always receive a new list object.
        metrics_by_task = {
            TaskType.REGRESSION: [Metric.RMSE, Metric.MAE, Metric.R2],
            TaskType.BINARY_CLASSIFICATION: [Metric.ACCURACY, Metric.ROCAUC],
            TaskType.MULTICLASS_CLASSIFICATION: [Metric.ACCURACY],
        }
        return metrics_by_task.get(self, [])
class NAStrategy(Enum):
    r"""How missing (NaN) values in a column are filled.

    Attributes:
        MEAN: Replaces NaN values with the mean of a
            :obj:`torch_frame.numerical` column.
        MOST_FREQUENT: Replaces NaN values with the most frequent category of
            a :obj:`torch_frame.categorical` column.
        ZEROS: Replaces NaN values with zeros in a
            :obj:`torch_frame.numerical` column. It is also the strategy
            reported by :obj:`is_multicategorical_strategy`.
        OLDEST_TIMESTAMP: Timestamp strategy (presumably fills NaN with the
            oldest timestamp of the column).
        NEWEST_TIMESTAMP: Timestamp strategy (presumably fills NaN with the
            newest timestamp of the column).
        MEDIAN_TIMESTAMP: Timestamp strategy (presumably fills NaN with the
            median timestamp of the column).
    """
    # Member order is part of the public behavior (Enum iteration order).
    MEAN = 'mean'
    MOST_FREQUENT = 'most_frequent'
    ZEROS = 'zeros'
    OLDEST_TIMESTAMP = 'oldest_timestamp'
    NEWEST_TIMESTAMP = 'newest_timestamp'
    MEDIAN_TIMESTAMP = 'median_timestamp'

    @property
    def is_categorical_strategy(self) -> bool:
        r"""Whether this strategy applies to categorical columns."""
        return self is NAStrategy.MOST_FREQUENT

    @property
    def is_multicategorical_strategy(self) -> bool:
        r"""Whether this strategy applies to multicategorical columns."""
        return self is NAStrategy.ZEROS

    @property
    def is_numerical_strategy(self) -> bool:
        r"""Whether this strategy applies to numerical columns."""
        numerical = (NAStrategy.MEAN, NAStrategy.ZEROS)
        return self in numerical

    @property
    def is_timestamp_strategy(self) -> bool:
        r"""Whether this strategy applies to timestamp columns."""
        timestamp = (
            NAStrategy.NEWEST_TIMESTAMP,
            NAStrategy.OLDEST_TIMESTAMP,
            NAStrategy.MEDIAN_TIMESTAMP,
        )
        return self in timestamp
# pandas containers, re-exported under short names.
Series: TypeAlias = pd.Series
DataFrame: TypeAlias = pd.DataFrame

# Accepted forms for selecting rows by position.
IndexSelectType: TypeAlias = int | list[int] | range | slice | Tensor
# Accepted forms for selecting columns by name.
ColumnSelectType: TypeAlias = str | list[str]

# A single tokenizer result maps string keys to tensors; tokenizers may also
# return a list of such mappings.
TextTokenizationMapping: TypeAlias = Mapping[str, Tensor]
TextTokenizationOutputs: TypeAlias = (
    list[TextTokenizationMapping] | TextTokenizationMapping)

# Storage types accepted as tensor data. The dict form maps string keys
# (presumably column/feature names — confirm with callers) to
# MultiNestedTensor values.
TensorData: TypeAlias = (
    Tensor
    | MultiNestedTensor
    | MultiEmbeddingTensor
    | dict[str, MultiNestedTensor]
)