torch_frame.datasets.DataFrameBenchmark
- class DataFrameBenchmark(root: str, task_type: TaskType, scale: str, idx: int, split_random_state: int = 42)
Bases: Dataset
A collection of standardized datasets for tabular learning, covering categorical and numerical features. The datasets are categorized according to their task types and scales.
- Parameters:
root (str) – Root directory.
task_type (TaskType) – The task type. Either TaskType.BINARY_CLASSIFICATION, TaskType.MULTICLASS_CLASSIFICATION, or TaskType.REGRESSION.
scale (str) – The scale of the dataset. "small" means 5K to 50K rows. "medium" means 50K to 500K rows. "large" means more than 500K rows.
idx (int) – The index of the dataset within a category specified via task_type and scale.
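As a minimal sketch of how the constructor arguments fit together, the helper below builds one dataset from the binary-classification, small-scale category. The helper name `load_small_binary_benchmark` is hypothetical, and the import path `torch_frame.typing.TaskType` is an assumption about where the enum lives. The imports are deferred to call time so that the definition itself does not need the package installed.

```python
from typing import Any


def load_small_binary_benchmark(root: str, idx: int = 0) -> Any:
    """Build one dataset from the binary-classification / small category.

    Hypothetical helper for illustration; it is not part of torch_frame.
    """
    # Lazy imports: defining this function does not require torch_frame,
    # but calling it does. The TaskType import path is an assumption.
    from torch_frame.datasets import DataFrameBenchmark
    from torch_frame.typing import TaskType

    return DataFrameBenchmark(
        root=root,  # root directory for the dataset files
        task_type=TaskType.BINARY_CLASSIFICATION,
        scale="small",  # one of "small", "medium", "large"
        idx=idx,  # position within the (task_type, scale) category
        split_random_state=42,  # the documented default
    )
```

Calling `load_small_binary_benchmark("data/")` would select the dataset at index 0 of that category; the root path here is only a placeholder.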
STATS:

| Task | Scale | Idx | #rows | #cols (numerical) | #cols (categorical) | #classes | Class object | Missing value ratio |
|---|---|---|---|---|---|---|---|---|
| binary_classification | small | 0 | 32,561 | 4 | 8 | 2 | AdultCensusIncome() | 0.0% |
| binary_classification | small | 1 | 8,124 | 0 | 22 | 2 | Mushroom() | 0.0% |
| binary_classification | small | 2 | 45,211 | 7 | 9 | 2 | BankMarketing() | 0.0% |
| binary_classification | small | 3 | 13,376 | 10 | 0 | 2 | TabularBenchmark(name='MagicTelescope') | 0.0% |
| binary_classification | small | 4 | 10,578 | 7 | 0 | 2 | TabularBenchmark(name='bank-marketing') | 0.0% |
| binary_classification | small | 5 | 20,634 | 8 | 0 | 2 | TabularBenchmark(name='california') | 0.0% |
| binary_classification | small | 6 | 16,714 | 10 | 0 | 2 | TabularBenchmark(name='credit') | 0.0% |
| binary_classification | small | 7 | 13,272 | 20 | 1 | 2 | TabularBenchmark(name='default-of-credit-card-clients') | 0.0% |
| binary_classification | small | 8 | 38,474 | 7 | 1 | 2 | TabularBenchmark(name='electricity') | 0.0% |
| binary_classification | small | 9 | 7,608 | 18 | 5 | 2 | TabularBenchmark(name='eye_movements') | 0.0% |
| binary_classification | small | 10 | 10,000 | 22 | 0 | 2 | TabularBenchmark(name='heloc') | 0.0% |
| binary_classification | small | 11 | 13,488 | 16 | 0 | 2 | TabularBenchmark(name='house_16H') | 0.0% |
| binary_classification | small | 12 | 10,082 | 26 | 0 | 2 | TabularBenchmark(name='pol') | 0.0% |
| binary_classification | small | 13 | 48,842 | 6 | 8 | 2 | Yandex(name='adult') | 0.0% |
| binary_classification | medium | 0 | 92,650 | 0 | 116 | 2 | Dota2() | 0.0% |
| binary_classification | medium | 1 | 199,523 | 7 | 34 | 2 | KDDCensusIncome() | 0.0% |
| binary_classification | medium | 2 | 71,090 | 7 | 0 | 2 | TabularBenchmark(name='Diabetes130US') | 0.0% |
| binary_classification | medium | 3 | 72,998 | 50 | 0 | 2 | TabularBenchmark(name='MiniBooNE') | 0.0% |
| binary_classification | medium | 4 | 58,252 | 23 | 8 | 2 | TabularBenchmark(name='albert') | 0.0% |
| binary_classification | medium | 5 | 423,680 | 10 | 44 | 2 | TabularBenchmark(name='covertype') | 0.0% |
| binary_classification | medium | 6 | 57,580 | 54 | 0 | 2 | TabularBenchmark(name='jannis') | 0.0% |
| binary_classification | medium | 7 | 111,762 | 24 | 8 | 2 | TabularBenchmark(name='road-safety') | 0.0% |
| binary_classification | medium | 8 | 98,050 | 28 | 0 | 2 | Yandex(name='higgs_small') | 0.0% |
| binary_classification | large | 0 | 940,160 | 24 | 0 | 2 | TabularBenchmark(name='Higgs') | 0.0% |
| multiclass_classification | medium | 0 | 108,000 | 128 | 0 | 1,000 | Yandex(name='aloi') | 0.0% |
| multiclass_classification | medium | 1 | 65,196 | 27 | 0 | 100 | Yandex(name='helena') | 0.0% |
| multiclass_classification | medium | 2 | 83,733 | 54 | 0 | 4 | Yandex(name='jannis') | 0.0% |
| multiclass_classification | large | 0 | 581,012 | 10 | 44 | 7 | ForestCoverType() | 0.0% |
| multiclass_classification | large | 1 | 1,025,010 | 5 | 5 | 10 | PokerHand() | 0.0% |
| multiclass_classification | large | 2 | 581,012 | 54 | 0 | 7 | Yandex(name='covtype') | 0.0% |
| regression | small | 0 | 17,379 | 6 | 5 | 1 | TabularBenchmark(name='Bike_Sharing_Demand') | 0.0% |
| regression | small | 1 | 10,692 | 7 | 4 | 1 | TabularBenchmark(name='Brazilian_houses') | 0.0% |
| regression | small | 2 | 8,192 | 21 | 0 | 1 | TabularBenchmark(name='cpu_act') | 0.0% |
| regression | small | 3 | 16,599 | 16 | 0 | 1 | TabularBenchmark(name='elevators') | 0.0% |
| regression | small | 4 | 21,613 | 15 | 2 | 1 | TabularBenchmark(name='house_sales') | 0.0% |
| regression | small | 5 | 20,640 | 8 | 0 | 1 | TabularBenchmark(name='houses') | 0.0% |
| regression | small | 6 | 10,081 | 6 | 0 | 1 | TabularBenchmark(name='sulfur') | 0.0% |
| regression | small | 7 | 21,263 | 79 | 0 | 1 | TabularBenchmark(name='superconduct') | 0.0% |
| regression | small | 8 | 8,885 | 252 | 3 | 1 | TabularBenchmark(name='topo_2_1') | 0.0% |
| regression | small | 9 | 8,641 | 3 | 1 | 1 | TabularBenchmark(name='visualizing_soil') | 0.0% |
| regression | small | 10 | 6,497 | 11 | 0 | 1 | TabularBenchmark(name='wine_quality') | 0.0% |
| regression | small | 11 | 8,885 | 42 | 0 | 1 | TabularBenchmark(name='yprop_4_1') | 0.0% |
| regression | small | 12 | 20,640 | 8 | 0 | 1 | Yandex(name='california_housing') | 0.0% |
| regression | medium | 0 | 188,318 | 25 | 99 | 1 | TabularBenchmark(name='Allstate_Claims_Severity') | 0.0% |
| regression | medium | 1 | 241,600 | 3 | 6 | 1 | TabularBenchmark(name='SGEMM_GPU_kernel_performance') | 0.0% |
| regression | medium | 2 | 53,940 | 6 | 3 | 1 | TabularBenchmark(name='diamonds') | 0.0% |
| regression | medium | 3 | 163,065 | 3 | 0 | 1 | TabularBenchmark(name='medical_charges') | 0.0% |
| regression | medium | 4 | 394,299 | 4 | 2 | 1 | TabularBenchmark(name='particulate-matter-ukair-2017') | 0.0% |
| regression | medium | 5 | 52,031 | 3 | 1 | 1 | TabularBenchmark(name='seattlecrime6') | 0.0% |
| regression | large | 0 | 1,000,000 | 5 | 0 | 1 | TabularBenchmark(name='Airlines_DepDelay_1M') | 0.0% |
| regression | large | 1 | 5,465,575 | 8 | 0 | 1 | TabularBenchmark(name='delays_zurich_transport') | 0.0% |
| regression | large | 2 | 581,835 | 9 | 0 | 1 | TabularBenchmark(name='nyc-taxi-green-dec-2016') | 0.0% |
| regression | large | 3 | 1,200,192 | 136 | 0 | 1 | Yandex(name='microsoft') | 0.0% |
| regression | large | 4 | 709,877 | 699 | 0 | 1 | Yandex(name='yahoo') | 0.0% |
| regression | large | 5 | 515,345 | 90 | 0 | 1 | Yandex(name='year') | 0.0% |
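The scale labels in STATS follow the row-count ranges given for the `scale` parameter. The sketch below encodes those ranges so that a row count can be mapped to its label. The function name `scale_for_rows` is hypothetical, and the handling of the exact boundaries (50,000 and 500,000 rows) is an assumption, since the documentation does not say which bucket they fall into.

```python
def scale_for_rows(num_rows: int) -> str:
    """Map a row count to "small", "medium", or "large".

    Hypothetical helper. Boundary handling is an assumption:
    small is [5K, 50K), medium is [50K, 500K], large is above 500K.
    """
    if num_rows < 5_000:
        # The documented ranges start at 5K rows.
        raise ValueError(f"{num_rows} rows is below the documented 5K minimum")
    if num_rows < 50_000:
        return "small"
    if num_rows <= 500_000:
        return "medium"
    return "large"


# Row counts taken from the STATS listing.
print(scale_for_rows(32_561))   # AdultCensusIncome()
print(scale_for_rows(423_680))  # TabularBenchmark(name='covertype')
print(scale_for_rows(940_160))  # TabularBenchmark(name='Higgs')
```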
- classmethod datasets_available(task_type: TaskType, scale: str) → list[tuple[str, dict[str, Any]]]
List of datasets available for a given task_type and scale.
- classmethod num_datasets_available(task_type: TaskType, scale: str)
Number of datasets available for a given task_type and scale.
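One way to use these class methods is to tabulate how many datasets each category holds before choosing an `idx`. The sketch below takes the benchmark class as an argument, so it relies only on the `num_datasets_available(task_type, scale)` signature shown above. The helper name `count_available` is hypothetical.

```python
from typing import Any, Dict, Iterable, Tuple


def count_available(
    benchmark_cls: Any,
    task_types: Iterable[Any],
    scales: Iterable[str],
) -> Dict[Tuple[Any, str], int]:
    """Return {(task_type, scale): number of datasets} for every combination.

    Hypothetical helper; benchmark_cls is expected to expose the
    num_datasets_available classmethod documented above.
    """
    scales = list(scales)  # allow the iterable to be reused per task type
    counts: Dict[Tuple[Any, str], int] = {}
    for task_type in task_types:
        for scale in scales:
            counts[(task_type, scale)] = benchmark_cls.num_datasets_available(
                task_type, scale
            )
    return counts


# Intended usage, assuming torch_frame is installed:
#   from torch_frame.datasets import DataFrameBenchmark
#   from torch_frame.typing import TaskType  # import path is an assumption
#   count_available(DataFrameBenchmark, list(TaskType), ["small", "medium", "large"])
```

Valid values of `idx` for a category would then be `0` up to one less than the count reported for it.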
- materialize(*args, **kwargs) → Dataset
Materializes the dataset into a tensor representation. From this point onwards, the dataset should be treated as read-only.
- Parameters:
device (torch.device, optional) – Device to load the TensorFrame object. (default: None)
path (str, optional) – If path is specified and a cached file exists, this will try to load the saved TensorFrame object and col_stats. If path is specified but a cached file does not exist, this will perform materialization and then save the TensorFrame object and col_stats to path. If path is None, this will materialize the dataset without caching. (default: None)
col_stats (Dict[str, Dict[StatType, Any]], optional) – Optional col_stats provided by the user. If not provided, the statistics are calculated from the dataframe itself. (default: None)
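The caching behaviour of `path` can be wrapped in a small helper, sketched below. It calls `materialize` with a cache path and returns the resulting tensor representation together with the column statistics. The helper name `materialize_with_cache` is hypothetical, and reading the results from `tensor_frame` and `col_stats` attributes on the dataset is an assumption about the `Dataset` interface.

```python
from typing import Any, Optional, Tuple


def materialize_with_cache(
    dataset: Any,
    cache_path: Optional[str] = None,
) -> Tuple[Any, Any]:
    """Materialize `dataset`, optionally caching to `cache_path`.

    Hypothetical helper. With a path, the first call is expected to
    materialize and save; later calls are expected to load the cached file.
    With cache_path=None, nothing is cached.
    """
    dataset.materialize(path=cache_path)
    # Assumed attribute names on the materialized dataset.
    return dataset.tensor_frame, dataset.col_stats
```

Since the dataset should be treated as read-only after materialization, any changes to the underlying data belong before this call.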