torch_frame.data.TensorFrame

class TensorFrame(feat_dict: dict[torch_frame.stype, TensorData], col_names_dict: dict[torch_frame.stype, list[str]], y: Tensor | None = None, num_rows: int | None = None)

Bases: object

A tensor frame holds a tensor for each table column. Table columns are organized by their semantic type stype (e.g., categorical, numerical) and mapped to a compact tensor representation (e.g., strings in a categorical column are mapped to indices in {0, ..., num_categories - 1}). These tensors are accessed through feat_dict. For instance, feat_dict[stype.numerical] stores a single concatenated tensor for all numerical features, where the first and second dimensions represent the rows and columns of the original data frame, respectively.
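The compact layout can be illustrated without any dependencies. The following is a minimal, hypothetical sketch. It assumes categories are indexed in sorted order, and torch_frame's own index assignment may differ. `encode_categorical` is an illustrative helper, not a torch_frame function. The second half of the sketch shows how per-column values are concatenated into a row-by-column table.

```python
from __future__ import annotations

# Hypothetical, dependency-free sketch of the layout described above;
# it is not torch_frame's implementation.


def encode_categorical(values: list[str]) -> tuple[list[int], list[str]]:
    """Map each string to an index in {0, ..., num_categories - 1}.

    Assumption: indices follow sorted category order.
    """
    categories = sorted(set(values))
    index_of = {category: i for i, category in enumerate(categories)}
    return [index_of[v] for v in values], categories


indices, categories = encode_categorical(["red", "blue", "red", "green"])
print(categories)  # ['blue', 'green', 'red']
print(indices)     # [2, 0, 2, 1]

# Concatenate two numerical columns into one row-by-column table:
# feat[i][j] holds row i of column j, mirroring feat_dict[stype.numerical][i, j].
num_1 = [0.5, 1.5, 2.5]
num_2 = [10.0, 20.0, 30.0]
feat = [list(row) for row in zip(num_1, num_2)]
print(feat[1])  # [1.5, 20.0]
```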

TensorFrame represents missing values as float('NaN') in floating-point tensors and as -1 otherwise.
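The following minimal sketch shows this convention on plain Python lists. `encode_missing` and `is_missing` are hypothetical helpers, and None stands in for a missing input value. Note that NaN never compares equal to itself, so floating-point slots must be tested with `math.isnan` rather than `==`.

```python
from __future__ import annotations

import math

MISSING_INT = -1  # sentinel for missing entries in non-floating-point columns


def encode_missing(values: list, floating: bool) -> list:
    """Replace None with NaN (floating-point columns) or -1 (otherwise)."""
    if floating:
        return [math.nan if v is None else float(v) for v in values]
    return [MISSING_INT if v is None else int(v) for v in values]


def is_missing(x: float | int) -> bool:
    """NaN != NaN, so floats need math.isnan; integers compare against -1."""
    if isinstance(x, float):
        return math.isnan(x)
    return x == MISSING_INT


num = encode_missing([1.0, None, 3.0], floating=True)
cat = encode_missing([0, 2, None], floating=False)
print([is_missing(x) for x in num])  # [False, True, False]
print(cat)                           # [0, 2, -1]
```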

col_names_dict maps each column in feat_dict to its original column name. For example, col_names_dict[stype.numerical][i] stores the name of the column held in feat_dict[stype.numerical][:, i].
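Because names are stored per stype and by position, finding a column by name means searching every stype's list. The following hypothetical sketch inverts col_names_dict into a name -> (stype, position) lookup. It uses strings in place of torch_frame.stype keys, and `build_col_index` is an illustrative helper, not part of the library.

```python
from __future__ import annotations

# Strings stand in for torch_frame.stype keys in this sketch.
col_names_dict = {
    "numerical": ["num_1", "num_2"],
    "categorical": ["cat_1", "cat_2", "cat_3"],
}


def build_col_index(col_names_dict: dict[str, list[str]]) -> dict[str, tuple[str, int]]:
    """Map each column name to (stype, i) such that the column is feat_dict[stype][:, i]."""
    index: dict[str, tuple[str, int]] = {}
    for stype, names in col_names_dict.items():
        for i, name in enumerate(names):
            index[name] = (stype, i)
    return index


col_index = build_col_index(col_names_dict)
print(col_index["cat_2"])  # ('categorical', 1)
```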

Additionally, TensorFrame can optionally store target values in y.

import torch
import torch_frame

tf = torch_frame.TensorFrame(
    feat_dict={
        # Two numerical columns:
        torch_frame.numerical: torch.randn(10, 2),
        # Three categorical columns:
        torch_frame.categorical: torch.randint(0, 5, (10, 3)),
    },
    col_names_dict={
        torch_frame.numerical: ['num_1', 'num_2'],
        torch_frame.categorical: ['cat_1', 'cat_2', 'cat_3'],
    },
)

print(len(tf))
# >>> 10

# Row-wise filtering:
tf = tf[torch.tensor([0, 2, 4, 6, 8])]
print(len(tf))
# >>> 5

# Transfer tensor frame to the GPU (only if one is available):
if torch.cuda.is_available():
    tf = tf.to('cuda')
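Row-wise indexing picks the same rows from every stype's tensor, so the numerical and categorical columns stay aligned row by row. Below is a minimal, dependency-free sketch of that idea. It uses plain lists in place of tensors and strings in place of stypes, and `select_rows` is a hypothetical helper, not the library's indexing code.

```python
from __future__ import annotations


def select_rows(feat_dict: dict[str, list[list]], index: list[int]) -> dict[str, list[list]]:
    """Keep the rows at `index` in every stype's feature table."""
    return {stype: [rows[i] for i in index] for stype, rows in feat_dict.items()}


feat_dict = {
    "numerical": [[float(r), float(r) * 10] for r in range(10)],      # 10 x 2
    "categorical": [[r % 5, (r + 1) % 5, (r + 2) % 5] for r in range(10)],  # 10 x 3
}
subset = select_rows(feat_dict, [0, 2, 4, 6, 8])
print(len(subset["numerical"]), len(subset["categorical"]))  # 5 5
print(subset["categorical"][1])  # [2, 3, 4]  (original row 2)
```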

validate() -> None

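Validates the TensorFrame object, e.g., checks that feat_dict and col_names_dict describe the same stypes and that the tensor shapes agree with the column names and the number of rows.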
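The kind of consistency check described here can be sketched on plain lists. This is a hypothetical illustration, not torch_frame's code. `validate_frame` is an illustrative name. The exact checks and error messages of the real validate() may differ, and raising ValueError is a choice made for this sketch.

```python
from __future__ import annotations


def validate_frame(feat_dict: dict, col_names_dict: dict, y: list | None = None) -> None:
    """Raise ValueError if feat_dict, col_names_dict and y are inconsistent."""
    if feat_dict.keys() != col_names_dict.keys():
        raise ValueError("feat_dict and col_names_dict must have the same stypes")
    num_rows = None
    for stype, rows in feat_dict.items():
        num_cols = len(col_names_dict[stype])
        # Every row must have exactly one entry per column name.
        if any(len(row) != num_cols for row in rows):
            raise ValueError(f"{stype}: row width does not match {num_cols} column name(s)")
        # All stypes must agree on the number of rows.
        if num_rows is None:
            num_rows = len(rows)
        elif len(rows) != num_rows:
            raise ValueError(f"{stype}: expected {num_rows} rows, got {len(rows)}")
    if y is not None and num_rows is not None and len(y) != num_rows:
        raise ValueError(f"y: expected {num_rows} entries, got {len(y)}")


ok_feat = {"numerical": [[0.1, 0.2], [0.3, 0.4]], "categorical": [[0], [1]]}
ok_names = {"numerical": ["num_1", "num_2"], "categorical": ["cat_1"]}
validate_frame(ok_feat, ok_names, y=[0, 1])  # consistent: passes silently

try:
    validate_frame(ok_feat, ok_names, y=[0, 1, 2])
except ValueError as err:
    print(err)  # y: expected 2 entries, got 3
```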

get_col_feat(col_name: str, *, return_stype: bool = False) -> TensorData | tuple[TensorData, torch_frame.stype]

Returns the feature of a given column.

Parameters:
  • col_name (str) – Input column name.

  • return_stype (bool, optional) – If set to True, will additionally return the semantic type of the column.

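Returns:

Column feature for the given col_name, with shape [num_rows, 1, *]. If return_stype is True, the semantic type of the column is returned as well.

Return type:

TensorData, or tuple[TensorData, torch_frame.stype] if return_stype is True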
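The following hypothetical sketch mimics this lookup on plain lists. It uses strings in place of stypes, and `get_col_feat` here is a free function, not the TensorFrame method. The column is wrapped per row so that the result keeps a column axis of size 1, matching the [num_rows, 1, *] shape. Raising KeyError for unknown names is an assumption of this sketch.

```python
from __future__ import annotations


def get_col_feat(feat_dict, col_names_dict, col_name, *, return_stype=False):
    """Return column `col_name` as a [num_rows, 1] nested list.

    With return_stype=True, also return the stype key the column lives under.
    """
    for stype, names in col_names_dict.items():
        if col_name in names:
            i = names.index(col_name)
            # Wrapping row[i] in a list keeps a column axis of size 1.
            feat = [[row[i]] for row in feat_dict[stype]]
            return (feat, stype) if return_stype else feat
    # Hypothetical choice for this sketch: unknown names raise KeyError.
    raise KeyError(col_name)


feat_dict = {"numerical": [[0.1, 5.0], [0.2, 6.0], [0.3, 7.0]], "categorical": [[1], [0], [2]]}
col_names_dict = {"numerical": ["num_1", "num_2"], "categorical": ["cat_1"]}
print(get_col_feat(feat_dict, col_names_dict, "num_2"))  # [[5.0], [6.0], [7.0]]
print(get_col_feat(feat_dict, col_names_dict, "cat_1", return_stype=True))  # ([[1], [0], [2]], 'categorical')
```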

property stypes: list[torch_frame._stype.stype]

Returns a canonical ordering of stypes in feat_dict.

property num_cols: int

The number of columns in the TensorFrame.

property num_rows: int

The number of rows in the TensorFrame.
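Both counts follow from the two dictionaries. The number of columns sums the column names over all stypes, and the number of rows is the length of any feature table. In the following sketch, `count_cols` and `count_rows` are hypothetical helpers over plain lists. One further assumption is ours and not stated in this reference: an explicitly supplied row count, like the constructor's num_rows argument, matters only when feat_dict is empty.

```python
from __future__ import annotations


def count_cols(col_names_dict: dict[str, list[str]]) -> int:
    """Total number of columns across all stypes."""
    return sum(len(names) for names in col_names_dict.values())


def count_rows(feat_dict: dict[str, list[list]], num_rows: int | None = None) -> int | None:
    """Row count of the first feature table; fall back to `num_rows` if there is none."""
    for rows in feat_dict.values():
        return len(rows)
    return num_rows  # assumption: only relevant for an empty feat_dict


# Shapes from the class example: 2 numerical and 3 categorical columns, 10 rows.
feat_dict = {
    "numerical": [[0.0, 0.0] for _ in range(10)],
    "categorical": [[0, 0, 0] for _ in range(10)],
}
col_names_dict = {
    "numerical": ["num_1", "num_2"],
    "categorical": ["cat_1", "cat_2", "cat_3"],
}
print(count_cols(col_names_dict), count_rows(feat_dict))  # 5 10
```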

property device: torch.device | None

The device of the TensorFrame.

property is_empty: bool

Returns True if the TensorFrame is empty.