from __future__ import annotations
import os
from abc import abstractmethod
import torch
from torch import Tensor
from torch_frame import Metric, TaskType, TensorFrame
# Metric that GBDT optimizes when the caller does not pass one explicitly.
# Task types absent from this mapping have no default metric.
DEFAULT_METRIC: dict[TaskType, Metric] = {
    TaskType.REGRESSION: Metric.RMSE,  # root mean squared error
    TaskType.BINARY_CLASSIFICATION: Metric.ROCAUC,  # area under ROC curve
    TaskType.MULTICLASS_CLASSIFICATION: Metric.ACCURACY,  # fraction correct
}
class GBDT:
    r"""Base class for GBDT (Gradient Boosting Decision Trees) models used as
    strong baseline.

    Subclasses implement :meth:`_tune`, :meth:`_predict` and :meth:`_load`.
    :meth:`save` additionally relies on a ``self.model`` attribute exposing
    ``save_model(path)``. That attribute is not set in this base class and is
    presumably assigned by the subclass during tuning/loading.

    Args:
        task_type (TaskType): The task type.
        num_classes (int, optional): If the task is multiclass classification,
            an optional num_classes can be used to specify the number of
            classes. Otherwise, we infer the value from the train data.
        metric (Metric, optional): Metric to optimize for, e.g.,
            :obj:`Metric.MAE`. If :obj:`None`, it will default to
            :obj:`Metric.RMSE` for regression, :obj:`Metric.ROCAUC` for binary
            classification, and :obj:`Metric.ACCURACY` for multi-
            class classification. (default: :obj:`None`).

    Raises:
        ValueError: If :obj:`metric` is given but does not support
            :obj:`task_type`.
        KeyError: If :obj:`metric` is :obj:`None` and :obj:`task_type` has no
            entry in ``DEFAULT_METRIC``.
    """
    def __init__(
        self,
        task_type: TaskType,
        num_classes: int | None = None,
        metric: Metric | None = None,
    ):
        self.task_type = task_type
        self._is_fitted: bool = False
        self._num_classes = num_classes
        # Set up metric. The default is looked up only when no metric is
        # given, so an explicitly supported metric also works for task types
        # that have no default in ``DEFAULT_METRIC``.
        if metric is None:
            self.metric = DEFAULT_METRIC[task_type]
        elif metric.supports_task_type(task_type):
            self.metric = metric
        else:
            raise ValueError(
                f"{task_type} does not support {metric}. Please choose "
                f"from {task_type.supported_metrics}.")

    # NOTE: GBDT does not use ``ABCMeta``, so ``@abstractmethod`` is not
    # enforced at instantiation; calling an unimplemented hook raises
    # ``NotImplementedError`` instead.
    @abstractmethod
    def _tune(self, tf_train: TensorFrame, tf_val: TensorFrame,
              num_trials: int, *args, **kwargs) -> None:
        r"""Subclass hook: tune hyperparameters and fit the model."""
        raise NotImplementedError

    @abstractmethod
    def _predict(self, tf_test: TensorFrame) -> Tensor:
        r"""Subclass hook: return raw predictions for :obj:`tf_test`."""
        raise NotImplementedError

    @abstractmethod
    def _load(self, path: str) -> None:
        r"""Subclass hook: restore the model stored at :obj:`path`."""
        raise NotImplementedError

    @property
    def is_fitted(self) -> bool:
        r"""Whether the GBDT is already fitted."""
        return self._is_fitted

    def tune(
        self,
        tf_train: TensorFrame,
        tf_val: TensorFrame,
        num_trials: int,
        *args,
        **kwargs,
    ) -> None:
        r"""Fit the model by performing hyperparameter tuning using Optuna. The
        number of trials is specified by num_trials.

        Args:
            tf_train (TensorFrame): The train data in :class:`TensorFrame`.
            tf_val (TensorFrame): The validation data in :class:`TensorFrame`.
            num_trials (int): Number of trials to perform hyper-parameter
                search.
            *args: Variable length argument list, forwarded to
                :meth:`_tune` after :obj:`num_trials`.
            **kwargs: Arbitrary keyword arguments, forwarded to :meth:`_tune`.

        Raises:
            RuntimeError: If either frame has no target (``y is None``).
        """
        if tf_train.y is None:
            raise RuntimeError("tf_train.y must be a Tensor, but None given.")
        if tf_val.y is None:
            raise RuntimeError("tf_val.y must be a Tensor, but None given.")
        # Pass num_trials positionally so it fills its own slot in
        # ``_tune``'s signature. Passing it by keyword after ``*args`` made
        # the first extra positional argument collide with it (TypeError).
        self._tune(tf_train, tf_val, num_trials, *args, **kwargs)
        self._is_fitted = True

    def predict(self, tf_test: TensorFrame) -> Tensor:
        r"""Predict the labels/values of the test data on the fitted model and
        return its predictions.

        - :obj:`TaskType.REGRESSION`: Returns raw numerical values.
        - :obj:`TaskType.BINARY_CLASSIFICATION`: Returns the probability of
          being positive.
        - :obj:`TaskType.MULTICLASS_CLASSIFICATION`: Returns the class label
          predictions.
        - :obj:`TaskType.MULTILABEL_CLASSIFICATION`: Returns a 2-D tensor with
          one row per test row.

        Args:
            tf_test (TensorFrame): The test data in :class:`TensorFrame`.

        Returns:
            Tensor: Predictions with ``len(tf_test)`` rows; 1-D except for
            multilabel classification, where it is 2-D.

        Raises:
            RuntimeError: If the model has not been fitted yet.
        """
        if not self.is_fitted:
            raise RuntimeError(
                f"{self.__class__.__name__} is not yet fitted. Please run "
                f"`tune()` first before attempting to predict.")
        pred = self._predict(tf_test)
        # Sanity-check the shape of what the subclass returned.
        if self.task_type == TaskType.MULTILABEL_CLASSIFICATION:
            assert pred.ndim == 2
        else:
            assert pred.ndim == 1
        assert len(pred) == len(tf_test)
        return pred

    def save(self, path: str) -> None:
        r"""Save the model.

        Parent directories of :obj:`path` are created if missing.

        Args:
            path (str): The path to save tuned GBDTs model.

        Raises:
            RuntimeError: If the model has not been fitted yet.
        """
        if not self.is_fitted:
            raise RuntimeError(
                f"{self.__class__.__name__} is not yet fitted. Please run "
                f"`tune()` first before attempting to save.")
        # A bare filename has an empty dirname, and ``os.makedirs("")``
        # raises FileNotFoundError, so only create a directory if there is one.
        dirname = os.path.dirname(path)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        # assumes the subclass has set ``self.model`` with ``save_model``.
        self.model.save_model(path)

    def load(self, path: str) -> None:
        r"""Load the model and mark this instance as fitted.

        Args:
            path (str): The path to load tuned GBDTs model.
        """
        self._load(path)
        self._is_fitted = True

    @torch.no_grad()
    def compute_metric(
        self,
        target: Tensor,
        pred: Tensor,
    ) -> float:
        r"""Compute evaluation metric given target labels :obj:`Tensor` and
        pred :obj:`Tensor`. Target contains the target values or labels; pred
        contains the prediction output from calling `predict()` function.

        For :obj:`Metric.ACCURACY` on binary classification, :obj:`pred` is
        treated as positive-class probabilities and thresholded at ``0.5``.

        Args:
            target (Tensor): Ground-truth values or labels.
            pred (Tensor): Predictions as returned by :meth:`predict`.

        Returns:
            score (float): Computed metric score.

        Raises:
            ValueError: If :obj:`self.metric` is not handled here.
        """
        if self.metric == Metric.RMSE:
            score = (pred - target).square().mean().sqrt().item()
        elif self.metric == Metric.MAE:
            score = (pred - target).abs().mean().item()
        elif self.metric == Metric.ROCAUC:
            from sklearn.metrics import roc_auc_score
            score = float(roc_auc_score(target.cpu(), pred.cpu()))
        elif self.metric == Metric.ACCURACY:
            if self.task_type == TaskType.BINARY_CLASSIFICATION:
                pred = pred > 0.5
            total_correct = (target == pred).sum().item()
            test_size = len(target)
            score = total_correct / test_size
        elif self.metric == Metric.R2:
            from sklearn.metrics import r2_score
            score = float(r2_score(target.cpu(), pred.cpu()))
        else:
            raise ValueError(f'{self.metric} is not supported.')
        return score