torch_frame.gbdt.GBDT

class GBDT(task_type: TaskType, num_classes: Optional[int] = None, metric: Optional[Metric] = None)[source]

Bases: object

Base class for GBDT (Gradient Boosting Decision Trees) models used as a strong baseline.

Parameters:
  • task_type (TaskType) – The task type.

  • num_classes (int, optional) – The number of classes, used only when the task is multiclass classification. If not given, the value is inferred from the training data.

  • metric (Metric, optional) – Metric to optimize for, e.g., Metric.MAE. If None, it defaults to Metric.RMSE for regression, Metric.ROCAUC for binary classification, and Metric.ACCURACY for multiclass classification. (default: None)
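The default-metric rule described above can be sketched in plain Python. The enums below are local stand-ins that only mirror the names used on this page (they are not imported from torch_frame), and `resolve_metric` is a hypothetical helper, not a library function.

```python
from enum import Enum
from typing import Optional


class TaskType(Enum):
    # Local stand-in for the task types named on this page.
    REGRESSION = "regression"
    BINARY_CLASSIFICATION = "binary_classification"
    MULTICLASS_CLASSIFICATION = "multiclass_classification"


class Metric(Enum):
    # Local stand-in for the metrics named on this page.
    RMSE = "rmse"
    MAE = "mae"
    ROCAUC = "rocauc"
    ACCURACY = "accuracy"


# Defaults as documented: one metric per task type.
_DEFAULT_METRIC = {
    TaskType.REGRESSION: Metric.RMSE,
    TaskType.BINARY_CLASSIFICATION: Metric.ROCAUC,
    TaskType.MULTICLASS_CLASSIFICATION: Metric.ACCURACY,
}


def resolve_metric(task_type: TaskType, metric: Optional[Metric] = None) -> Metric:
    """Return the explicit metric if given, else the documented default."""
    if metric is not None:
        return metric
    return _DEFAULT_METRIC[task_type]


chosen = resolve_metric(TaskType.REGRESSION, Metric.MAE)  # explicit choice wins
fallback = resolve_metric(TaskType.BINARY_CLASSIFICATION)  # falls back to ROCAUC
```

An explicit `metric` always takes precedence; the task type matters only when `metric` is None.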

property is_fitted: bool

Whether the GBDT is already fitted.

tune(tf_train: TensorFrame, tf_val: TensorFrame, num_trials: int, *args, **kwargs) → None[source]

Fit the model by performing hyperparameter tuning using Optuna. The number of trials is specified by num_trials.

Parameters:
  • tf_train (TensorFrame) – The training data as a TensorFrame.

  • tf_val (TensorFrame) – The validation data as a TensorFrame.

  • num_trials (int) – Number of trials to run in the hyperparameter search.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.
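To illustrate what a trial budget means, here is a minimal sketch of a search loop that evaluates `num_trials` candidate configurations and keeps the best one. It uses plain random search from the standard library as a stand-in, not Optuna. The names `tune_sketch` and `toy_objective` and the sampled hyperparameters are all hypothetical. The sketch minimizes its score, which suits an error metric such as RMSE; a metric such as accuracy would be maximized instead.

```python
import random
from typing import Callable, Dict, Optional, Tuple

Params = Dict[str, float]


def tune_sketch(
    objective: Callable[[Params], float],
    num_trials: int,
    seed: int = 0,
) -> Tuple[Optional[Params], float]:
    """Random search: sample num_trials configurations, keep the lowest score."""
    rng = random.Random(seed)
    best_params: Optional[Params] = None
    best_score = float("inf")
    for _ in range(num_trials):
        # Hypothetical search space: log-uniform learning rate, integer depth.
        params = {
            "learning_rate": 10 ** rng.uniform(-3, 0),
            "max_depth": rng.randint(3, 11),
        }
        score = objective(params)
        if score < best_score:
            best_params, best_score = params, score
    return best_params, best_score


def toy_objective(params: Params) -> float:
    # Stand-in for "fit on the training data, score on the validation data".
    return (params["learning_rate"] - 0.1) ** 2 + 0.01 * abs(params["max_depth"] - 6)


best_params, best_score = tune_sketch(toy_objective, num_trials=20)
```

A larger `num_trials` explores more configurations at proportionally higher cost, since each trial corresponds to one full fit-and-validate cycle.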

predict(tf_test: TensorFrame) → Tensor[source]

Predict the labels or values of the test data with the fitted model and return the predictions. The form of the output depends on the task type:

  • TaskType.REGRESSION: Returns raw numerical values.

  • TaskType.BINARY_CLASSIFICATION: Returns the probability of being positive.

  • TaskType.MULTICLASS_CLASSIFICATION: Returns the class label predictions.
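Because the binary case returns positive-class probabilities rather than hard labels, a caller who needs 0/1 labels has to apply a threshold. The sketch below shows that step on plain Python lists; the helper name `probabilities_to_labels` and the 0.5 cutoff are assumptions for illustration, not part of the library.

```python
from typing import List, Sequence


def probabilities_to_labels(probs: Sequence[float], threshold: float = 0.5) -> List[int]:
    """Map positive-class probabilities to 0/1 labels (ties go to the positive class)."""
    return [1 if p >= threshold else 0 for p in probs]


# Example values standing in for the output of a binary-classification predict() call.
binary_pred = [0.10, 0.70, 0.50]
labels = probabilities_to_labels(binary_pred)
```

No such step is needed for multiclass classification, where the output is already class labels, or for regression, where it is the raw values.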

save(path: str) → None[source]

Save the model.

Parameters:

path (str) – The path where the tuned GBDT model is saved.

load(path: str) → None[source]

Load the model.

Parameters:

path (str) – The path from which the tuned GBDT model is loaded.

compute_metric(target: Tensor, pred: Tensor) → float[source]

Compute the evaluation metric from the target Tensor and the pred Tensor. target contains the ground-truth values or labels; pred contains the output of predict().

Returns:

Computed metric score.

Return type:

score (float)
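For reference, two of the metrics named on this page can be written out on plain Python lists. This is a sketch of the standard definitions of RMSE and accuracy, not the library's implementation; the function names are hypothetical.

```python
import math
from typing import Sequence


def _check(target: Sequence, pred: Sequence) -> None:
    # Both sequences must be non-empty and aligned element by element.
    if len(target) != len(pred) or len(target) == 0:
        raise ValueError("target and pred must be non-empty and of equal length")


def rmse(target: Sequence[float], pred: Sequence[float]) -> float:
    """Root mean squared error: lower is better."""
    _check(target, pred)
    squared_error = sum((t - p) ** 2 for t, p in zip(target, pred))
    return math.sqrt(squared_error / len(target))


def accuracy(target: Sequence[int], pred: Sequence[int]) -> float:
    """Fraction of predictions that match the target labels: higher is better."""
    _check(target, pred)
    correct = sum(1 for t, p in zip(target, pred) if t == p)
    return correct / len(target)


regression_score = rmse([1.0, 2.0, 3.0], [1.0, 2.0, 5.0])
classification_score = accuracy([0, 1, 2], [0, 1, 1])
```

The direction of improvement differs between the two: a lower RMSE is better, while a higher accuracy is better, which matters when comparing scores across models.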