Introduction by Example
PyTorch Frame is a tabular deep learning extension library for PyTorch. Modern data is stored in a table format with heterogeneous columns, each with its own semantic type, e.g., numerical (such as age or price), categorical (such as gender or product type), time, text (such as descriptions or comments), images, etc. The goal of PyTorch Frame is to build a deep learning framework that performs effective machine learning on such complex and diverse data.
Many recent tabular models follow the modular design of FeatureEncoder, TableConv, and Decoder. PyTorch Frame is designed to facilitate the creation, implementation, and evaluation of deep learning models for tabular data under this modular architecture.
Please refer to the Modular Design of Deep Tabular Models page for more information.
In this doc, we introduce the fundamental concepts of PyTorch Frame through self-contained examples.
At its core, PyTorch Frame provides the following main features:
Common Benchmark Datasets
PyTorch Frame contains a large number of common benchmark datasets. The list of all datasets is available in torch_frame.datasets.
Initializing datasets is straightforward in PyTorch Frame: initializing a dataset automatically downloads its raw files and processes the columns.
In the example below, we will use one of the pre-loaded datasets, which contains the Titanic passengers. If you would like to use your own dataset, refer to the example in Handling Heterogeneous Semantic Types.
from torch_frame.datasets import Titanic
dataset = Titanic(root='/tmp/titanic')
len(dataset)
>>> 891
dataset.feat_cols
>>> ['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked']
dataset.materialize()
>>> Titanic()
dataset.df.head(5)
>>>
Survived Pclass Name Sex Age SibSp Parch Ticket Fare Cabin Embarked
PassengerId
1 0 3 Braund, Mr. Owen Harris male 22.0 1 0 A/5 21171 7.2500 NaN S
2 1 1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 0 PC 17599 71.2833 C85 C
3 1 3 Heikkinen, Miss. Laina female 26.0 0 0 STON/O2. 3101282 7.9250 NaN S
4 1 1 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 0 113803 53.1000 C123 S
5 0 3 Allen, Mr. William Henry male 35.0 0 0 373450 8.0500 NaN S
PyTorch Frame also supports custom datasets, so that you can use PyTorch Frame for your own problem. Let's say you have prepared your pandas.DataFrame as df with five columns: cat1, cat2, num1, num2, and y. Creating a torch_frame.data.Dataset object is very easy:
import torch_frame
from torch_frame.data import Dataset
# Specify the stype of each column with a dictionary.
col_to_stype = {"cat1": torch_frame.categorical, "cat2": torch_frame.categorical,
"num1": torch_frame.numerical, "num2": torch_frame.numerical,
"y": torch_frame.categorical}
# Set "y" as the target column.
dataset = Dataset(df, col_to_stype=col_to_stype, target_col="y")
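Here, df could for example be a toy table like the following (a minimal sketch; the values are made up purely for illustration):
import pandas as pd
# Hypothetical data: two categorical columns, two numerical columns, and a target column.
df = pd.DataFrame({
    'cat1': ['a', 'b', 'a', 'c'],
    'cat2': ['x', 'y', 'y', 'x'],
    'num1': [0.1, 0.5, 0.3, 0.9],
    'num2': [1.0, 2.0, 3.0, 4.0],
    'y': [0, 1, 1, 0],
})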
Data Handling of Tables
A table contains different columns with different data types. Each data type is described by a semantic type, which we refer to as stype. Currently, PyTorch Frame supports the following stypes:
- stype.categorical denotes categorical columns.
- stype.numerical denotes numerical columns.
- stype.multicategorical denotes multicategorical columns.
- stype.text_embedded denotes text columns that are pre-embedded via some text encoder.
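These stypes are members of the torch_frame.stype enum; module-level aliases such as torch_frame.categorical (used in the col_to_stype dictionary above) refer to the same members:
import torch_frame
from torch_frame import stype
stype.categorical
>>> <stype.categorical: 'categorical'>
torch_frame.numerical == stype.numerical
>>> True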
A table in PyTorch Frame is described by an instance of TensorFrame, which holds the following attributes by default:
- col_names_dict: A dictionary holding the column names for each stype.
- feat_dict: A dictionary holding the Tensor of different stypes. For stype.numerical and stype.categorical, the shape of the Tensor is [num_rows, num_cols], while for stype.text_embedded, the shape is [num_rows, num_cols, emb_dim].
- y (optional): A tensor containing the target values for prediction.
Note
The set of keys in feat_dict must exactly match the set of keys in col_names_dict. TensorFrame is validated at initialization time.
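To make these attributes concrete, here is a minimal hand-constructed TensorFrame with random data (a sketch for illustration; in practice, a TensorFrame is usually produced by materializing a Dataset, as shown below):
import torch
from torch_frame import TensorFrame, stype
tf = TensorFrame(
    feat_dict={
        stype.numerical: torch.randn(10, 2),              # [num_rows, num_cols]
        stype.categorical: torch.randint(0, 3, (10, 1)),  # [num_rows, num_cols]
    },
    col_names_dict={
        stype.numerical: ['num1', 'num2'],
        stype.categorical: ['cat1'],
    },
    y=torch.randint(0, 2, (10,)),  # optional target values
)
Passing mismatched keys in feat_dict and col_names_dict raises an error at initialization.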
Creating a TensorFrame from a Dataset is referred to as materialization. materialize() converts the raw data frame in a Dataset into Tensors and stores them in a TensorFrame.
materialize() also provides an optional argument path to cache the TensorFrame and col_stats. If path is specified, PyTorch Frame will first try to load the saved TensorFrame and col_stats during materialization. If no saved object is found at that path, PyTorch Frame will materialize the dataset and save the materialized TensorFrame and col_stats to the path.
Note
Materialization does minimal processing of the original features, e.g., no normalization or missing-value handling is performed. PyTorch Frame converts missing values in columns of categorical torch_frame.stype to -1 and missing values in columns of numerical torch_frame.stype to NaN. We expect NaN/missing-value handling and normalization to be handled on the model side via torch_frame.nn.encoder.StypeEncoder.
The TensorFrame object has Tensor at its core; therefore, it is friendly for training and inference with PyTorch. In PyTorch Frame, we build data loaders and models around TensorFrame, benefiting from all the efficiency and flexibility of PyTorch.
from torch_frame import stype
dataset.materialize() # materialize the dataset
dataset.materialize(path='/tmp/titanic/data.pt') # materialize the dataset with caching
dataset.materialize(path='/tmp/titanic/data.pt') # next materialization will load the cache
tensor_frame = dataset.tensor_frame
tensor_frame.feat_dict.keys()
>>> dict_keys([<stype.categorical: 'categorical'>, <stype.numerical: 'numerical'>])
tensor_frame.feat_dict[stype.numerical]
>>> tensor([[22.0000, 1.0000, 0.0000, 7.2500],
[38.0000, 1.0000, 0.0000, 71.2833],
[26.0000, 0.0000, 0.0000, 7.9250],
...,
[ nan, 1.0000, 2.0000, 23.4500],
[26.0000, 0.0000, 0.0000, 30.0000],
[32.0000, 0.0000, 0.0000, 7.7500]])
tensor_frame.feat_dict[stype.categorical]
>>> tensor([[0, 0, 0],
[1, 1, 1],
[0, 1, 0],
...,
[0, 1, 0],
[1, 0, 1],
[0, 0, 2]])
tensor_frame.col_names_dict
>>> {<stype.categorical: 'categorical'>: ['Pclass', 'Sex', 'Embarked'], <stype.numerical: 'numerical'>: ['Age', 'SibSp', 'Parch', 'Fare']}
tensor_frame.y
>>> tensor([0, 1, 1, ..., 0, 1, 0])
A TensorFrame contains the following basic properties:
tensor_frame.stypes
>>> [<stype.numerical: 'numerical'>, <stype.categorical: 'categorical'>]
tensor_frame.num_cols
>>> 7
tensor_frame.num_rows
>>> 891
tensor_frame.device
>>> device(type='cpu')
We support transferring the data in a TensorFrame to devices supported by PyTorch.
tensor_frame.to("cpu")
tensor_frame.to("cuda")
Once a Dataset is materialized, we can retrieve column statistics on the data. For each stype, a different set of statistics is calculated.
For categorical features, StatType.COUNT contains a tuple of two lists: the first list holds the ordered category names, and the second holds the corresponding category counts, sorted from high to low.
For numerical features:
- StatType.MEAN denotes the mean value of the numerical feature,
- StatType.STD denotes the standard deviation,
- StatType.QUANTILES contains a list of the minimum value, first quartile (25th percentile), median (50th percentile), third quartile (75th percentile), and maximum value of the column.
dataset.col_to_stype
>>> {'Survived': <stype.categorical: 'categorical'>, 'Pclass': <stype.categorical: 'categorical'>, 'Sex': <stype.categorical: 'categorical'>, 'Age': <stype.numerical: 'numerical'>, 'SibSp': <stype.numerical: 'numerical'>, 'Parch': <stype.numerical: 'numerical'>, 'Fare': <stype.numerical: 'numerical'>, 'Embarked': <stype.categorical: 'categorical'>}
dataset.col_stats['Sex']
>>> {<StatType.COUNT: 'COUNT'>: (['male', 'female'], [577, 314])}
dataset.col_stats['Age']
>>> {<StatType.MEAN: 'MEAN'>: 29.69911764705882, <StatType.STD: 'STD'>: 14.516321150817316, <StatType.QUANTILES: 'QUANTILES'>: [0.42, 20.125, 28.0, 38.0, 80.0]}
Now let's say you have a new pandas.DataFrame called new_df, and you want to convert it to a corresponding TensorFrame object. You can achieve this as follows:
new_tf = dataset.convert_to_tensor_frame(new_df)
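For example, new_df could be a single new row with the same feature columns as the original dataset (a hypothetical record; categorical values should be categories seen during materialization):
import pandas as pd
# A hypothetical new Titanic passenger with the same feature columns as the dataset.
new_df = pd.DataFrame({
    'Pclass': [3], 'Sex': ['male'], 'Age': [25.0], 'SibSp': [0],
    'Parch': [0], 'Fare': [8.05], 'Embarked': ['S'],
})
new_tf = dataset.convert_to_tensor_frame(new_df)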
Mini-batches
Neural networks are usually trained in a mini-batch fashion. PyTorch Frame contains its own DataLoader, which can load a Dataset or a TensorFrame in mini-batches.
from torch_frame.data import DataLoader
data_loader = DataLoader(tensor_frame, batch_size=32,
shuffle=True)
for batch in data_loader:
batch
>>> TensorFrame(
num_cols=7,
num_rows=32,
categorical (3): ['Pclass', 'Sex', 'Embarked'],
numerical (4): ['Age', 'SibSp', 'Parch', 'Fare'],
has_target=True,
device='cpu',
)
Learning Methods on Tabular Data
After learning about data handling, datasets, and loaders in PyTorch Frame, it's time to implement our first model!
Now let's implement a model called ExampleTransformer, which uses TabTransformerConv as its convolution layer. Initializing a StypeWiseFeatureEncoder requires col_stats and col_names_dict; we can directly get them as properties of any materialized dataset.
from typing import Any, Dict, List
from torch import Tensor
from torch.nn import Linear, Module, ModuleList
import torch_frame
from torch_frame import TensorFrame, stype
from torch_frame.data.stats import StatType
from torch_frame.nn.conv import TabTransformerConv
from torch_frame.nn.encoder import (
EmbeddingEncoder,
LinearEncoder,
StypeWiseFeatureEncoder,
)
class ExampleTransformer(Module):
def __init__(
self,
channels: int,
out_channels: int,
num_layers: int,
num_heads: int,
col_stats: Dict[str, Dict[StatType, Any]],
col_names_dict: Dict[torch_frame.stype, List[str]],
):
super().__init__()
self.encoder = StypeWiseFeatureEncoder(
out_channels=channels,
col_stats=col_stats,
col_names_dict=col_names_dict,
stype_encoder_dict={
stype.categorical: EmbeddingEncoder(),
stype.numerical: LinearEncoder()
},
)
self.tab_transformer_convs = ModuleList([
TabTransformerConv(
channels=channels,
num_heads=num_heads,
) for _ in range(num_layers)
])
self.decoder = Linear(channels, out_channels)
    def forward(self, tf: TensorFrame) -> Tensor:
        # x has shape [batch_size, num_cols, channels] after encoding.
        x, _ = self.encoder(tf)
        for tab_transformer_conv in self.tab_transformer_convs:
            x = tab_transformer_conv(x)
        # Mean-pool over the column dimension, then decode to out_channels.
        out = self.decoder(x.mean(dim=1))
        return out
In the example above, EmbeddingEncoder is used to encode the categorical features and LinearEncoder is used to encode the numerical features. The column embeddings are then passed through the layers of TabTransformerConv. Finally, the outputs are mean-pooled over the column dimension and fed into a torch.nn.Linear decoder.
Let's create a train-test split and build the data loaders.
from torch_frame.datasets import Yandex
from torch_frame.data import DataLoader
dataset = Yandex(root='/tmp/adult', name='adult')
dataset.materialize()
dataset.shuffle()
train_dataset, test_dataset = dataset[:0.8], dataset[0.8:]
train_loader = DataLoader(train_dataset.tensor_frame, batch_size=128,
shuffle=True)
test_loader = DataLoader(test_dataset.tensor_frame, batch_size=128)
Let’s train this model for 50 epochs:
import torch
import torch.nn.functional as F
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ExampleTransformer(
channels=32,
out_channels=dataset.num_classes,
num_layers=2,
num_heads=8,
col_stats=train_dataset.col_stats,
col_names_dict=train_dataset.tensor_frame.col_names_dict,
).to(device)
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(50):
for tf in train_loader:
tf = tf.to(device)
pred = model(tf)
loss = F.cross_entropy(pred, tf.y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
Finally, we can evaluate our model on the test split:
model.eval()
correct = 0
with torch.no_grad():  # gradients are not needed for evaluation
    for tf in test_loader:
        tf = tf.to(device)
        pred = model(tf)
        pred_class = pred.argmax(dim=-1)
        correct += (tf.y == pred_class).sum()
acc = int(correct) / len(test_dataset)
print(f'Accuracy: {acc:.4f}')
>>> Accuracy: 0.8447
This is all it takes to implement your first deep tabular network. Happy hacking!