Source code for torch_frame.datasets.data_frame_benchmark

from __future__ import annotations

from typing import Any

import pandas as pd

import torch_frame
from torch_frame.typing import TaskType
from torch_frame.utils import generate_random_split

SPLIT_COL = 'split'


class DataFrameBenchmark(torch_frame.data.Dataset):
    r"""A collection of standardized datasets for tabular learning, covering
    categorical and numerical features. The datasets are categorized according
    to their task types and scales.

    Args:
        root (str): Root directory.
        task_type (TaskType): The task type. Either
            :obj:`TaskType.BINARY_CLASSIFICATION`,
            :obj:`TaskType.MULTICLASS_CLASSIFICATION`, or
            :obj:`TaskType.REGRESSION`
        scale (str): The scale of the dataset. :obj:`"small"` means 5K to 50K
            rows. :obj:`"medium"` means 50K to 500K rows. :obj:`"large"`
            means more than 500K rows.
        idx (int): The index of the dataset within a category specified via
            :obj:`task_type` and :obj:`scale`.
        split_random_state (int, optional): Seed passed to
            :obj:`generate_random_split` to create the split column with a
            train ratio of 0.8 and a validation ratio of 0.1.
            (default: :obj:`42`)

    Raises:
        ValueError: If :obj:`idx` is not smaller than the number of datasets
            available for the given :obj:`task_type` and :obj:`scale`.
        AssertionError: If the number of rows of the selected dataset does
            not fall into the range implied by :obj:`scale`.

    **STATS:**

    .. list-table::
        :widths: 20 10 10 10 10 10 20 20 10
        :header-rows: 1

        * - Task
          - Scale
          - Idx
          - #rows
          - #cols (numerical)
          - #cols (categorical)
          - #classes
          - Class object
          - Missing value ratio
        * - binary_classification
          - small
          - 0
          - 32,561
          - 4
          - 8
          - 2
          - AdultCensusIncome()
          - 0.0%
        * - binary_classification
          - small
          - 1
          - 8,124
          - 0
          - 22
          - 2
          - Mushroom()
          - 0.0%
        * - binary_classification
          - small
          - 2
          - 45,211
          - 7
          - 9
          - 2
          - BankMarketing()
          - 0.0%
        * - binary_classification
          - small
          - 3
          - 13,376
          - 10
          - 0
          - 2
          - TabularBenchmark(name='MagicTelescope')
          - 0.0%
        * - binary_classification
          - small
          - 4
          - 10,578
          - 7
          - 0
          - 2
          - TabularBenchmark(name='bank-marketing')
          - 0.0%
        * - binary_classification
          - small
          - 5
          - 20,634
          - 8
          - 0
          - 2
          - TabularBenchmark(name='california')
          - 0.0%
        * - binary_classification
          - small
          - 6
          - 16,714
          - 10
          - 0
          - 2
          - TabularBenchmark(name='credit')
          - 0.0%
        * - binary_classification
          - small
          - 7
          - 13,272
          - 20
          - 1
          - 2
          - TabularBenchmark(name='default-of-credit-card-clients')
          - 0.0%
        * - binary_classification
          - small
          - 8
          - 38,474
          - 7
          - 1
          - 2
          - TabularBenchmark(name='electricity')
          - 0.0%
        * - binary_classification
          - small
          - 9
          - 7,608
          - 18
          - 5
          - 2
          - TabularBenchmark(name='eye_movements')
          - 0.0%
        * - binary_classification
          - small
          - 10
          - 10,000
          - 22
          - 0
          - 2
          - TabularBenchmark(name='heloc')
          - 0.0%
        * - binary_classification
          - small
          - 11
          - 13,488
          - 16
          - 0
          - 2
          - TabularBenchmark(name='house_16H')
          - 0.0%
        * - binary_classification
          - small
          - 12
          - 10,082
          - 26
          - 0
          - 2
          - TabularBenchmark(name='pol')
          - 0.0%
        * - binary_classification
          - small
          - 13
          - 48,842
          - 6
          - 8
          - 2
          - Yandex(name='adult')
          - 0.0%
        * - binary_classification
          - medium
          - 0
          - 92,650
          - 0
          - 116
          - 2
          - Dota2()
          - 0.0%
        * - binary_classification
          - medium
          - 1
          - 199,523
          - 7
          - 34
          - 2
          - KDDCensusIncome()
          - 0.0%
        * - binary_classification
          - medium
          - 2
          - 71,090
          - 7
          - 0
          - 2
          - TabularBenchmark(name='Diabetes130US')
          - 0.0%
        * - binary_classification
          - medium
          - 3
          - 72,998
          - 50
          - 0
          - 2
          - TabularBenchmark(name='MiniBooNE')
          - 0.0%
        * - binary_classification
          - medium
          - 4
          - 58,252
          - 23
          - 8
          - 2
          - TabularBenchmark(name='albert')
          - 0.0%
        * - binary_classification
          - medium
          - 5
          - 423,680
          - 10
          - 44
          - 2
          - TabularBenchmark(name='covertype')
          - 0.0%
        * - binary_classification
          - medium
          - 6
          - 57,580
          - 54
          - 0
          - 2
          - TabularBenchmark(name='jannis')
          - 0.0%
        * - binary_classification
          - medium
          - 7
          - 111,762
          - 24
          - 8
          - 2
          - TabularBenchmark(name='road-safety')
          - 0.0%
        * - binary_classification
          - medium
          - 8
          - 98,050
          - 28
          - 0
          - 2
          - Yandex(name='higgs_small')
          - 0.0%
        * - binary_classification
          - large
          - 0
          - 940,160
          - 24
          - 0
          - 2
          - TabularBenchmark(name='Higgs')
          - 0.0%
        * - multiclass_classification
          - medium
          - 0
          - 108,000
          - 128
          - 0
          - 1,000
          - Yandex(name='aloi')
          - 0.0%
        * - multiclass_classification
          - medium
          - 1
          - 65,196
          - 27
          - 0
          - 100
          - Yandex(name='helena')
          - 0.0%
        * - multiclass_classification
          - medium
          - 2
          - 83,733
          - 54
          - 0
          - 4
          - Yandex(name='jannis')
          - 0.0%
        * - multiclass_classification
          - large
          - 0
          - 581,012
          - 10
          - 44
          - 7
          - ForestCoverType()
          - 0.0%
        * - multiclass_classification
          - large
          - 1
          - 1,025,010
          - 5
          - 5
          - 10
          - PokerHand()
          - 0.0%
        * - multiclass_classification
          - large
          - 2
          - 581,012
          - 54
          - 0
          - 7
          - Yandex(name='covtype')
          - 0.0%
        * - regression
          - small
          - 0
          - 17,379
          - 6
          - 5
          - 1
          - TabularBenchmark(name='Bike_Sharing_Demand')
          - 0.0%
        * - regression
          - small
          - 1
          - 10,692
          - 7
          - 4
          - 1
          - TabularBenchmark(name='Brazilian_houses')
          - 0.0%
        * - regression
          - small
          - 2
          - 8,192
          - 21
          - 0
          - 1
          - TabularBenchmark(name='cpu_act')
          - 0.0%
        * - regression
          - small
          - 3
          - 16,599
          - 16
          - 0
          - 1
          - TabularBenchmark(name='elevators')
          - 0.0%
        * - regression
          - small
          - 4
          - 21,613
          - 15
          - 2
          - 1
          - TabularBenchmark(name='house_sales')
          - 0.0%
        * - regression
          - small
          - 5
          - 20,640
          - 8
          - 0
          - 1
          - TabularBenchmark(name='houses')
          - 0.0%
        * - regression
          - small
          - 6
          - 10,081
          - 6
          - 0
          - 1
          - TabularBenchmark(name='sulfur')
          - 0.0%
        * - regression
          - small
          - 7
          - 21,263
          - 79
          - 0
          - 1
          - TabularBenchmark(name='superconduct')
          - 0.0%
        * - regression
          - small
          - 8
          - 8,885
          - 252
          - 3
          - 1
          - TabularBenchmark(name='topo_2_1')
          - 0.0%
        * - regression
          - small
          - 9
          - 8,641
          - 3
          - 1
          - 1
          - TabularBenchmark(name='visualizing_soil')
          - 0.0%
        * - regression
          - small
          - 10
          - 6,497
          - 11
          - 0
          - 1
          - TabularBenchmark(name='wine_quality')
          - 0.0%
        * - regression
          - small
          - 11
          - 8,885
          - 42
          - 0
          - 1
          - TabularBenchmark(name='yprop_4_1')
          - 0.0%
        * - regression
          - small
          - 12
          - 20,640
          - 8
          - 0
          - 1
          - Yandex(name='california_housing')
          - 0.0%
        * - regression
          - medium
          - 0
          - 188,318
          - 25
          - 99
          - 1
          - TabularBenchmark(name='Allstate_Claims_Severity')
          - 0.0%
        * - regression
          - medium
          - 1
          - 241,600
          - 3
          - 6
          - 1
          - TabularBenchmark(name='SGEMM_GPU_kernel_performance')
          - 0.0%
        * - regression
          - medium
          - 2
          - 53,940
          - 6
          - 3
          - 1
          - TabularBenchmark(name='diamonds')
          - 0.0%
        * - regression
          - medium
          - 3
          - 163,065
          - 3
          - 0
          - 1
          - TabularBenchmark(name='medical_charges')
          - 0.0%
        * - regression
          - medium
          - 4
          - 394,299
          - 4
          - 2
          - 1
          - TabularBenchmark(name='particulate-matter-ukair-2017')
          - 0.0%
        * - regression
          - medium
          - 5
          - 52,031
          - 3
          - 1
          - 1
          - TabularBenchmark(name='seattlecrime6')
          - 0.0%
        * - regression
          - large
          - 0
          - 1,000,000
          - 5
          - 0
          - 1
          - TabularBenchmark(name='Airlines_DepDelay_1M')
          - 0.0%
        * - regression
          - large
          - 1
          - 5,465,575
          - 8
          - 0
          - 1
          - TabularBenchmark(name='delays_zurich_transport')
          - 0.0%
        * - regression
          - large
          - 2
          - 581,835
          - 9
          - 0
          - 1
          - TabularBenchmark(name='nyc-taxi-green-dec-2016')
          - 0.0%
        * - regression
          - large
          - 3
          - 1,200,192
          - 136
          - 0
          - 1
          - Yandex(name='microsoft')
          - 0.0%
        * - regression
          - large
          - 4
          - 709,877
          - 699
          - 0
          - 1
          - Yandex(name='yahoo')
          - 0.0%
        * - regression
          - large
          - 5
          - 515,345
          - 90
          - 0
          - 1
          - Yandex(name='year')
          - 0.0%
    """
    # Maps task type value -> scale -> ordered list of
    # (dataset class name in `torch_frame.datasets`, constructor kwargs).
    # The position in each list is the `idx` accepted by the constructor.
    dataset_categorization_dict: dict[str, dict[str, list[tuple[str, dict[
        str, Any]]]]] = {
            'binary_classification': {
                'small': [
                    ('AdultCensusIncome', {}),
                    ('Mushroom', {}),
                    ('BankMarketing', {}),
                    ('TabularBenchmark', {'name': 'MagicTelescope'}),
                    ('TabularBenchmark', {'name': 'bank-marketing'}),
                    ('TabularBenchmark', {'name': 'california'}),
                    ('TabularBenchmark', {'name': 'credit'}),
                    ('TabularBenchmark', {
                        'name': 'default-of-credit-card-clients'
                    }),
                    ('TabularBenchmark', {'name': 'electricity'}),
                    ('TabularBenchmark', {'name': 'eye_movements'}),
                    ('TabularBenchmark', {'name': 'heloc'}),
                    ('TabularBenchmark', {'name': 'house_16H'}),
                    ('TabularBenchmark', {'name': 'pol'}),
                    ('Yandex', {'name': 'adult'}),
                ],
                'medium': [
                    ('Dota2', {}),
                    ('KDDCensusIncome', {}),
                    ('TabularBenchmark', {'name': 'Diabetes130US'}),
                    ('TabularBenchmark', {'name': 'MiniBooNE'}),
                    ('TabularBenchmark', {'name': 'albert'}),
                    ('TabularBenchmark', {'name': 'covertype'}),
                    ('TabularBenchmark', {'name': 'jannis'}),
                    ('TabularBenchmark', {'name': 'road-safety'}),
                    ('Yandex', {'name': 'higgs_small'}),
                ],
                'large': [
                    ('TabularBenchmark', {'name': 'Higgs'}),
                ],
            },
            'multiclass_classification': {
                'small': [],
                'medium': [
                    ('Yandex', {'name': 'aloi'}),
                    ('Yandex', {'name': 'helena'}),
                    ('Yandex', {'name': 'jannis'}),
                ],
                'large': [
                    ('ForestCoverType', {}),
                    ('PokerHand', {}),
                    ('Yandex', {'name': 'covtype'}),
                ],
            },
            'regression': {
                'small': [
                    ('TabularBenchmark', {'name': 'Bike_Sharing_Demand'}),
                    ('TabularBenchmark', {'name': 'Brazilian_houses'}),
                    ('TabularBenchmark', {'name': 'cpu_act'}),
                    ('TabularBenchmark', {'name': 'elevators'}),
                    ('TabularBenchmark', {'name': 'house_sales'}),
                    ('TabularBenchmark', {'name': 'houses'}),
                    ('TabularBenchmark', {'name': 'sulfur'}),
                    ('TabularBenchmark', {'name': 'superconduct'}),
                    ('TabularBenchmark', {'name': 'topo_2_1'}),
                    ('TabularBenchmark', {'name': 'visualizing_soil'}),
                    ('TabularBenchmark', {'name': 'wine_quality'}),
                    ('TabularBenchmark', {'name': 'yprop_4_1'}),
                    ('Yandex', {'name': 'california_housing'}),
                ],
                'medium': [
                    ('TabularBenchmark', {
                        'name': 'Allstate_Claims_Severity'
                    }),
                    ('TabularBenchmark', {
                        'name': 'SGEMM_GPU_kernel_performance'
                    }),
                    ('TabularBenchmark', {'name': 'diamonds'}),
                    ('TabularBenchmark', {'name': 'medical_charges'}),
                    ('TabularBenchmark', {
                        'name': 'particulate-matter-ukair-2017'
                    }),
                    ('TabularBenchmark', {'name': 'seattlecrime6'}),
                ],
                'large': [
                    ('TabularBenchmark', {'name': 'Airlines_DepDelay_1M'}),
                    ('TabularBenchmark', {
                        'name': 'delays_zurich_transport'
                    }),
                    ('TabularBenchmark', {
                        'name': 'nyc-taxi-green-dec-2016'
                    }),
                    ('Yandex', {'name': 'microsoft'}),
                    ('Yandex', {'name': 'yahoo'}),
                    ('Yandex', {'name': 'year'}),
                ],
            },
        }

    @classmethod
    def datasets_available(
        cls,
        task_type: TaskType,
        scale: str,
    ) -> list[tuple[str, dict[str, Any]]]:
        r"""List of datasets available for a given :obj:`task_type` and
        :obj:`scale`.

        Each entry is a :obj:`(class_name, kwargs)` tuple taken from
        :obj:`dataset_categorization_dict`. An unknown task type value or
        scale results in a :class:`KeyError` from the dictionary lookup.
        """
        return cls.dataset_categorization_dict[task_type.value][scale]

    @classmethod
    def num_datasets_available(cls, task_type: TaskType, scale: str) -> int:
        r"""Number of datasets available for a given :obj:`task_type` and
        :obj:`scale`.
        """
        return len(cls.datasets_available(task_type, scale))

    def __init__(
        self,
        root: str,
        task_type: TaskType,
        scale: str,
        idx: int,
        split_random_state: int = 42,
    ):
        self.root = root
        # Kept separately from the `task_type` exposed by the base class so
        # that `materialize` can compare the two.
        self._task_type = task_type
        self.scale = scale
        self.idx = idx

        datasets = self.datasets_available(task_type, scale)
        if idx >= len(datasets):
            raise ValueError(
                f"The idx needs to be smaller than {len(datasets)}, which is "
                f"the number of available datasets for task_type: "
                f"{task_type.value} and scale: {scale} (got idx: {idx}).")

        # Resolve the dataset class by name and instantiate it.
        class_name, kwargs = datasets[idx]
        dataset = getattr(torch_frame.datasets, class_name)(root=root,
                                                            **kwargs)
        self.cls_str = str(dataset)

        # Add split col, replacing any split column shipped with the dataset.
        df = dataset.df
        if SPLIT_COL in df.columns:
            df.drop(columns=[SPLIT_COL], inplace=True)
        split = generate_random_split(length=len(df),
                                      seed=split_random_state,
                                      train_ratio=0.8, val_ratio=0.1)
        # Attach the split positionally. Concatenating a fresh data frame
        # along `axis=1` would align on index labels and misplace rows
        # whenever `df` does not carry a default integer index.
        df = df.assign(**{SPLIT_COL: split})

        # For regression task, we normalize the target.
        if task_type == TaskType.REGRESSION:
            ser = df[dataset.target_col]
            df[dataset.target_col] = (ser - ser.mean()) / ser.std()

        # Check the scale. Explicit raises are used instead of `assert`
        # statements, which are skipped when Python runs with `-O` and give
        # no diagnostic message.
        num_rows = dataset.num_rows
        if num_rows < 5000:
            raise AssertionError(
                f"The dataset has {num_rows} rows, which is fewer than the "
                f"5000 rows required by the smallest scale.")
        if num_rows < 50000:
            expected_scale = "small"
        elif num_rows < 500000:
            expected_scale = "medium"
        else:
            expected_scale = "large"
        if scale != expected_scale:
            raise AssertionError(
                f"The dataset has {num_rows} rows, which corresponds to "
                f"scale: {expected_scale} (got scale: {scale}).")

        super().__init__(df=df, col_to_stype=dataset.col_to_stype,
                         target_col=dataset.target_col, split_col=SPLIT_COL)
        del dataset

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(\n'
                f'  task_type={self._task_type.value},\n'
                f'  scale={self.scale},\n'
                f'  idx={self.idx},\n'
                f'  cls={self.cls_str}\n'
                f')')

    def materialize(self, *args, **kwargs) -> torch_frame.data.Dataset:
        r"""Materialize the dataset via the parent class and verify that the
        resulting :obj:`task_type` matches the one given at construction.

        All arguments are forwarded unchanged to the parent implementation.

        Raises:
            RuntimeError: If :obj:`self.task_type` differs from the task type
                passed to the constructor.
        """
        super().materialize(*args, **kwargs)
        if self.task_type != self._task_type:
            raise RuntimeError(f"task type does not match. It should be "
                               f"{self.task_type.value} but specified as "
                               f"{self._task_type.value}.")
        return self