torch_frame.gbdt.GBDT
- class GBDT(task_type: TaskType, num_classes: Optional[int] = None, metric: Optional[Metric] = None)
Bases: object
Base class for GBDT (Gradient Boosting Decision Trees) models used as strong baselines.
- Parameters:
task_type (TaskType) – The task type, e.g., TaskType.REGRESSION, TaskType.BINARY_CLASSIFICATION, or TaskType.MULTICLASS_CLASSIFICATION.
num_classes (int, optional) – For multiclass classification, optionally specifies the number of classes. If it is not given, the value is inferred from the training data.
metric (Metric, optional) – Metric to optimize for, e.g., Metric.MAE. If None, it defaults to Metric.RMSE for regression, Metric.ROCAUC for binary classification, and Metric.ACCURACY for multiclass classification. (default: None)
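The defaulting rule for metric amounts to a small lookup keyed by task type. The following is a minimal sketch and not torch_frame's source. TaskType and Metric here are local stand-ins that only mirror the member names used above, and resolve_metric is a hypothetical helper.

```python
from enum import Enum
from typing import Optional


class TaskType(Enum):
    # Local stand-in mirroring the task types named on this page (not torch_frame's enum).
    REGRESSION = "regression"
    BINARY_CLASSIFICATION = "binary_classification"
    MULTICLASS_CLASSIFICATION = "multiclass_classification"


class Metric(Enum):
    # Local stand-in for the metric names mentioned on this page (not torch_frame's enum).
    RMSE = "rmse"
    MAE = "mae"
    ROCAUC = "rocauc"
    ACCURACY = "accuracy"


# Default metric per task type, as described for metric=None.
DEFAULT_METRIC = {
    TaskType.REGRESSION: Metric.RMSE,
    TaskType.BINARY_CLASSIFICATION: Metric.ROCAUC,
    TaskType.MULTICLASS_CLASSIFICATION: Metric.ACCURACY,
}


def resolve_metric(task_type: TaskType, metric: Optional[Metric] = None) -> Metric:
    """Hypothetical helper: return the given metric, or the task's default if None."""
    if metric is not None:
        return metric
    return DEFAULT_METRIC[task_type]
```

An explicitly passed metric, such as Metric.MAE for regression, always takes precedence over the default.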
- tune(tf_train: TensorFrame, tf_val: TensorFrame, num_trials: int, *args, **kwargs) → None
Fit the model by performing hyperparameter tuning using Optuna. The number of trials is specified by num_trials.
- Parameters:
tf_train (TensorFrame) – The training data as a TensorFrame.
tf_val (TensorFrame) – The validation data as a TensorFrame.
num_trials (int) – Number of trials to perform in the hyperparameter search.
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
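Conceptually, tune runs num_trials trials. Each trial samples a hyperparameter configuration, fits on the training data, and scores on the validation data, and the best configuration is kept. The sketch below shows that loop under stated assumptions. It substitutes plain random search for Optuna's sampler, and random_search, sample_params and fit_and_score are hypothetical names. The toy "model" (a shrunk mean) stands in for a real GBDT, whose search space is not documented here.

```python
import random
from typing import Callable, Dict, Tuple


def random_search(
    sample_params: Callable[[random.Random], Dict[str, float]],
    fit_and_score: Callable[[Dict[str, float]], float],
    num_trials: int,
    seed: int = 0,
) -> Tuple[Dict[str, float], float]:
    """Try num_trials sampled configurations; keep the one with the lowest validation loss."""
    if num_trials < 1:
        raise ValueError("num_trials must be at least 1")
    rng = random.Random(seed)
    best_params: Dict[str, float] = {}
    best_score = float("inf")
    for _ in range(num_trials):
        params = sample_params(rng)    # sample one candidate configuration
        score = fit_and_score(params)  # "fit on train, evaluate on val"
        if score < best_score:
            best_params, best_score = params, score
    return best_params, best_score


# Toy usage: the "model" predicts shrink * mean(training targets).
y_train = [1.0, 2.0, 3.0]
y_val = [2.0, 2.0]


def sample_params(rng: random.Random) -> Dict[str, float]:
    return {"shrink": rng.uniform(0.0, 2.0)}


def fit_and_score(params: Dict[str, float]) -> float:
    pred = params["shrink"] * sum(y_train) / len(y_train)
    # Validation RMSE of the constant prediction.
    return (sum((y - pred) ** 2 for y in y_val) / len(y_val)) ** 0.5


best_params, best_score = random_search(sample_params, fit_and_score, num_trials=20)
```

A guided sampler such as Optuna's uses earlier trials to choose later ones, which random search does not. The contract is the same in both cases: more trials explore more of the search space, at proportionally higher cost.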
- predict(tf_test: TensorFrame) → Tensor
Predict the labels/values of the test data using the fitted model and return the predictions. The output depends on the task type:
TaskType.REGRESSION: Returns raw numerical values.
TaskType.BINARY_CLASSIFICATION: Returns the probability of being positive.
TaskType.MULTICLASS_CLASSIFICATION: Returns the class label predictions.
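The three return conventions listed above can be illustrated by mapping raw model scores to outputs. This is a hypothetical sketch operating on plain lists rather than Tensors. It assumes the binary case starts from logits and the multiclass case from per-class scores. Real GBDT backends may already emit probabilities, so it does not describe how torch_frame computes its outputs.

```python
import math
from typing import List, Sequence


def to_regression_output(raw: Sequence[float]) -> List[float]:
    # Regression: raw numerical values are returned unchanged.
    return list(raw)


def to_binary_output(logits: Sequence[float]) -> List[float]:
    # Binary classification: probability of the positive class via a
    # numerically stable logistic sigmoid (assumes inputs are logits).
    probs = []
    for z in logits:
        if z >= 0:
            probs.append(1.0 / (1.0 + math.exp(-z)))
        else:
            e = math.exp(z)
            probs.append(e / (1.0 + e))
    return probs


def to_multiclass_output(scores: Sequence[Sequence[float]]) -> List[int]:
    # Multiclass classification: index of the highest-scoring class per row.
    return [max(range(len(row)), key=row.__getitem__) for row in scores]
```

For example, a logit of 0.0 maps to a probability of 0.5. A row of class scores [0.1, 2.0, 0.3] maps to class label 1.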
- save(path: str) → None
Save the model.
- Parameters:
path (str) – The path to save the tuned GBDT model to.
- load(path: str) → None
Load the model.
- Parameters:
path (str) – The path to load the tuned GBDT model from.
- compute_metric(target: Tensor, pred: Tensor) → float
Compute the evaluation metric given the target Tensor and the prediction Tensor. target contains the target values or labels; pred contains the prediction output from calling predict().
- Returns:
Computed metric score.
- Return type:
score (float)
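The sketch below shows what the four metrics named on this page compute. It works on plain Python lists rather than Tensors, and the helper names and string keys (evaluate, "rmse", and so on) are hypothetical. It is not torch_frame's implementation. ROC AUC is computed as the fraction of (positive, negative) pairs in which the positive gets the higher score, with ties counting half. This is quadratic in the sample size but easy to check by hand.

```python
import math
from typing import Callable, Dict, Sequence


def rmse(target: Sequence[float], pred: Sequence[float]) -> float:
    # Root mean squared error: sqrt(mean((target - pred)^2)).
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(target, pred)) / len(target))


def mae(target: Sequence[float], pred: Sequence[float]) -> float:
    # Mean absolute error: mean(|target - pred|).
    return sum(abs(t - p) for t, p in zip(target, pred)) / len(target)


def accuracy(target: Sequence[float], pred: Sequence[float]) -> float:
    # Fraction of predicted class labels that match the targets.
    return sum(1 for t, p in zip(target, pred) if t == p) / len(target)


def roc_auc(target: Sequence[float], prob: Sequence[float]) -> float:
    # Probability that a random positive outscores a random negative (ties count half).
    pos = [p for t, p in zip(target, prob) if t == 1]
    neg = [p for t, p in zip(target, prob) if t == 0]
    if not pos or not neg:
        raise ValueError("ROC AUC needs both positive and negative targets")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


METRICS: Dict[str, Callable[[Sequence[float], Sequence[float]], float]] = {
    "rmse": rmse,
    "mae": mae,
    "accuracy": accuracy,
    "rocauc": roc_auc,
}


def evaluate(metric: str, target: Sequence[float], pred: Sequence[float]) -> float:
    """Hypothetical dispatcher: look up a metric by name and apply it."""
    return METRICS[metric](target, pred)
```

RMSE and MAE are losses, so lower is better. Accuracy and ROC AUC are scores, so higher is better. When comparing tuned models, check which direction the chosen metric uses.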