Source code for torch_frame.data.tensor_frame

from __future__ import annotations

import copy
from collections.abc import Callable
from typing import Any

import torch
from torch import Tensor

import torch_frame
from torch_frame.data.multi_embedding_tensor import MultiEmbeddingTensor
from torch_frame.data.multi_nested_tensor import MultiNestedTensor
from torch_frame.data.multi_tensor import _MultiTensor
from torch_frame.typing import IndexSelectType, TensorData


class TensorFrame:
    r"""A tensor frame holds a :pytorch:`PyTorch` tensor for each table column.

    Table columns are organized into their semantic types
    :class:`~torch_frame.stype` (*e.g.*, categorical, numerical) and mapped to
    a compact tensor representation (*e.g.*, strings in a categorical column
    are mapped to indices from :obj:`{0, ..., num_categories - 1}`), and can
    be accessed through :obj:`feat_dict`.
    For instance, :obj:`feat_dict[stype.numerical]` stores a concatenated
    :pytorch:`PyTorch` tensor for all numerical features, where the first and
    second dimension represents the row and column in the original data frame,
    respectively.

    :class:`TensorFrame` handles missing values via :obj:`float('NaN')` for
    floating-point tensors, and :obj:`-1` otherwise.

    :obj:`col_names_dict` maps each column in :obj:`feat_dict` to their
    original column name.
    For example, :obj:`col_names_dict[stype.numerical][i]` stores the column
    name of :obj:`feat_dict[stype.numerical][:, i]`.

    Additionally, :class:`TensorFrame` can store any target values in
    :obj:`y`.

    .. code-block:: python

        import torch_frame

        tf = torch_frame.TensorFrame(
            feat_dict = {
                # Two numerical columns:
                torch_frame.numerical: torch.randn(10, 2),
                # Three categorical columns:
                torch_frame.categorical: torch.randint(0, 5, (10, 3)),
            },
            col_names_dict = {
                torch_frame.numerical: ['num_1', 'num_2'],
                torch_frame.categorical: ['cat_1', 'cat_2', 'cat_3'],
            },
        )

        print(len(tf))
        >>> 10

        # Row-wise filtering:
        tf = tf[torch.tensor([0, 2, 4, 6, 8])]
        print(len(tf))
        >>> 5

        # Transfer tensor frame to the GPU:
        tf = tf.to('cuda')

    Args:
        feat_dict (dict): Maps each stype to its feature data: either a
            single tensor-like object or a dictionary of tensor-like objects.
        col_names_dict (dict): Maps each stype to the list of column names
            stored under that stype. Must have the same keys as
            :obj:`feat_dict`.
        y (Tensor, optional): Target values with one entry per row.
        num_rows (int, optional): If given, used as the number of rows
            instead of inferring it from :obj:`feat_dict`.
    """
    def __init__(
        self,
        feat_dict: dict[torch_frame.stype, TensorData],
        col_names_dict: dict[torch_frame.stype, list[str]],
        y: Tensor | None = None,
        num_rows: int | None = None,
    ) -> None:
        self.feat_dict = feat_dict
        self.col_names_dict = col_names_dict
        self.y = y
        self._num_rows = num_rows
        self.validate()

        # Quick mapping from column names into their (stype, idx) pairs in
        # col_names_dict. Used for fast get_col_feat.
        self._col_to_stype_idx: dict[str, tuple[torch_frame.stype, int]] = {
            col: (stype_name, idx)
            for stype_name, cols in self.col_names_dict.items()
            for idx, col in enumerate(cols)
        }

    def validate(self) -> None:
        r"""Validates the :class:`TensorFrame` object.

        Raises:
            ValueError: If :obj:`feat_dict` and :obj:`col_names_dict` have
                different keys, if column names are not given as a list, if a
                feature has fewer than two dimensions, or if the number of
                rows/columns of a feature (or the length of :obj:`y`) does
                not match the expected value.
            RuntimeError: If any stype has an empty list of column names.
        """
        if self.feat_dict.keys() != self.col_names_dict.keys():
            raise ValueError(
                f"The keys of feat_dict and col_names_dict must be the same, "
                f"but got {self.feat_dict.keys()} for feat_dict and "
                f"{self.col_names_dict.keys()} for col_names_dict.")

        num_rows = self.num_rows
        empty_stypes: list[torch_frame.stype] = []
        for stype_name, feats in self.feat_dict.items():
            col_names = self.col_names_dict[stype_name]
            if not isinstance(col_names, list):
                raise ValueError(
                    f"col_names_dict[{stype_name}] must be a list of column "
                    f"names.")
            num_cols = len(col_names)
            if num_cols == 0:
                # Collected and reported together after the loop.
                empty_stypes.append(stype_name)

            # Normalize to a list so that plain and dictionary-valued
            # features share the same shape checks.
            tensors: list[(Tensor | MultiNestedTensor | MultiEmbeddingTensor)]
            if isinstance(feats, dict):
                tensors = list(feats.values())
            else:
                tensors = [feats]

            for tensor in tensors:
                if tensor.dim() < 2:
                    raise ValueError(f"feat_dict['{stype_name}'] must be at "
                                     f"least 2-dimensional")
                if num_cols != tensor.size(1):
                    raise ValueError(
                        f"The expected number of columns for {stype_name} "
                        f"feature is {num_cols}, which does not align with "
                        f"the column dimensionality of "
                        f"feat_dict[{stype_name}] (got {tensor.size(1)})")
                if tensor.size(0) != num_rows:
                    raise ValueError(
                        f"The length of elements in feat_dict are "
                        f"not aligned, got {tensor.size(0)} but "
                        f"expected {num_rows}.")

        if empty_stypes:
            raise RuntimeError(
                f"Empty columns for the following stypes: {empty_stypes}."
                f"Please manually delete the above stypes.")

        if self.y is not None and len(self.y) != num_rows:
            raise ValueError(
                f"The length of y is {len(self.y)}, which is not aligned "
                f"with the number of rows ({num_rows}).")

    def get_col_feat(
        self,
        col_name: str,
        *,
        return_stype: bool = False,
    ) -> TensorData | tuple[TensorData, torch_frame.stype]:
        r"""Get feature of a given column.

        Args:
            col_name (str): Input column name.
            return_stype (bool, optional): If set to :obj:`True`, will
                additionally return the semantic type of the column.

        Returns:
            TensorData: Column feature for the given :obj:`col_name`.
            The shape is :obj:`[num_rows, 1, *]`.

        Raises:
            ValueError: If :obj:`col_name` is not a known column.
        """
        if col_name not in self._col_to_stype_idx:
            raise ValueError(f"'{col_name}' is not available in the "
                             f"'{self.__class__.__name__}' object")

        stype_name, idx = self._col_to_stype_idx[col_name]
        feat = self.feat_dict[stype_name]

        out: TensorData
        if isinstance(feat, dict):
            col_feat: dict[str, MultiNestedTensor] = {}
            for key, mnt in feat.items():
                value = mnt[:, idx]
                assert isinstance(value, MultiNestedTensor)
                col_feat[key] = value
            out = col_feat
        elif isinstance(feat, _MultiTensor):
            out = feat[:, idx]
        else:
            assert isinstance(feat, Tensor)
            # Integer indexing drops the column dimension; restore it.
            out = feat[:, idx].unsqueeze(1)

        return (out, stype_name) if return_stype else out

    @property
    def stypes(self) -> list[torch_frame.stype]:
        r"""Returns a canonical ordering of stypes in :obj:`feat_dict`."""
        # Ordering follows the declaration order of `torch_frame.stype`.
        return [
            stype for stype in torch_frame.stype if stype in self.feat_dict
        ]

    @property
    def num_cols(self) -> int:
        r"""The number of columns in the :class:`TensorFrame`."""
        return sum(
            len(col_names) for col_names in self.col_names_dict.values())

    @property
    def num_rows(self) -> int:
        r"""The number of rows in the :class:`TensorFrame`."""
        # An explicitly provided row count takes precedence.
        if self._num_rows is not None:
            return self._num_rows
        if self.is_empty:
            return 0
        feat = next(iter(self.feat_dict.values()))
        if isinstance(feat, dict):
            return len(next(iter(feat.values())))
        return len(feat)

    @property
    def device(self) -> torch.device | None:
        r"""The device of the :class:`TensorFrame`."""
        if self.is_empty:
            return None
        # The device is read from the first stored feature only.
        feat = next(iter(self.feat_dict.values()))
        if isinstance(feat, dict):
            return next(iter(feat.values())).device
        return feat.device

    @property
    def is_empty(self) -> bool:
        r"""Returns :obj:`True` if the :class:`TensorFrame` is empty."""
        return len(self.feat_dict) == 0

    # Python Built-ins ########################################################

    def __len__(self) -> int:
        return self.num_rows

    def __eq__(self, other: Any) -> bool:
        r"""Compares type, length, target, column names and features."""
        # Match instance type
        if not isinstance(other, TensorFrame):
            return False
        # Match length
        if len(self) != len(other):
            return False
        # Match target
        if self.y is not None:
            if other.y is None:
                return False
            if not torch.allclose(other.y, self.y):
                return False
        elif other.y is not None:
            return False
        # Match col_names_dict
        if self.col_names_dict != other.col_names_dict:
            return False
        # Match feat_dict
        for stype_name, self_feat in self.feat_dict.items():
            other_feat = other.feat_dict[stype_name]
            if isinstance(self_feat, Tensor):
                if not isinstance(other_feat, Tensor):
                    return False
                if self_feat.shape != other_feat.shape:
                    return False
                if not torch.allclose(self_feat, other_feat, equal_nan=True):
                    return False
            elif isinstance(self_feat, MultiNestedTensor):
                if not isinstance(other_feat, MultiNestedTensor):
                    return False
                if not MultiNestedTensor.allclose(self_feat, other_feat,
                                                  equal_nan=True):
                    return False
            elif isinstance(self_feat, MultiEmbeddingTensor):
                if not isinstance(other_feat, MultiEmbeddingTensor):
                    return False
                if not MultiEmbeddingTensor.allclose(self_feat, other_feat,
                                                     equal_nan=True):
                    return False
            elif isinstance(self_feat, dict):
                if not isinstance(other_feat, dict):
                    return False
                if self_feat.keys() != other_feat.keys():
                    return False
                for feat_name in self_feat.keys():
                    if not MultiNestedTensor.allclose(
                            self_feat[feat_name],
                            other_feat[feat_name],
                            equal_nan=True,
                    ):
                        return False
        return True

    def __ne__(self, other: Any) -> bool:
        return not self.__eq__(other)

    # `__neq__` is not a Python special method and was never invoked by `!=`.
    # It is kept as an alias so that explicit callers keep working.
    __neq__ = __ne__

    def __repr__(self) -> str:
        stype_repr: str
        if self.is_empty:
            stype_repr = ""
            device_repr = " device=None,\n"
        else:
            stype_repr = "\n".join([
                f" {stype} ({len(col_names)}): {col_names},"
                for stype, col_names in self.col_names_dict.items()
            ])
            stype_repr += "\n"
            device_repr = f" device='{self.device}',\n"

        return (f"{self.__class__.__name__}(\n"
                f" num_cols={self.num_cols},\n"
                f" num_rows={self.num_rows},\n"
                f"{stype_repr}"
                f" has_target={self.y is not None},\n"
                f"{device_repr}"
                f")")

    def __getitem__(self, index: IndexSelectType) -> TensorFrame:
        r"""Returns a new :class:`TensorFrame` restricted to the rows
        selected by :obj:`index`.
        """
        if isinstance(index, int):
            # Wrap integers so that the row dimension is preserved.
            index = [index]

        def fn(x):
            if isinstance(x, dict):
                return {key: value[index] for key, value in x.items()}
            return x[index]

        out = self._apply(fn)

        if self._num_rows is not None:
            # The row count was given explicitly and cannot be re-inferred
            # from the features, so index a zero-column placeholder with
            # `num_rows` rows to obtain the resulting number of rows.
            device = index.device if isinstance(index, Tensor) else 'cpu'
            dummy = torch.empty((self.num_rows, 0), device=device)
            out._num_rows = dummy[index].size(0)

        return out

    def __copy__(self) -> TensorFrame:
        out = self.__class__.__new__(self.__class__)
        out.__dict__.update(self.__dict__)
        # Only the two top-level dictionaries are copied; the feature
        # objects stored inside them are shared with `self`.
        out.feat_dict = copy.copy(out.feat_dict)
        out.col_names_dict = copy.copy(out.col_names_dict)
        return out

    # Device Transfer #########################################################

    def to(self, *args, **kwargs) -> TensorFrame:
        r"""Returns a copy with :meth:`to` applied to every feature and to
        :obj:`y`.
        """
        return self._apply_method('to', *args, **kwargs)

    def cpu(self, *args, **kwargs) -> TensorFrame:
        r"""Returns a copy with :meth:`cpu` applied to every feature and to
        :obj:`y`.
        """
        return self._apply_method('cpu', *args, **kwargs)

    def cuda(self, *args, **kwargs) -> TensorFrame:
        r"""Returns a copy with :meth:`cuda` applied to every feature and to
        :obj:`y`.
        """
        return self._apply_method('cuda', *args, **kwargs)

    def pin_memory(self, *args, **kwargs) -> TensorFrame:
        r"""Returns a copy with :meth:`pin_memory` applied to every feature
        and to :obj:`y`.
        """
        return self._apply_method('pin_memory', *args, **kwargs)

    # Helper Functions ########################################################

    def _apply_method(self, method_name: str, *args, **kwargs) -> TensorFrame:
        r"""Calls the method :obj:`method_name` with the given arguments on
        every stored feature (and on :obj:`y`) of a copy of this frame.
        """
        def fn(x):
            if isinstance(x, dict):
                # Build a new dictionary rather than assigning into `x`:
                # `__copy__` is shallow, so nested dictionaries are shared
                # with `self` and in-place assignment would alter it as well.
                return {
                    key: getattr(value, method_name)(*args, **kwargs)
                    for key, value in x.items()
                }
            return getattr(x, method_name)(*args, **kwargs)

        return self._apply(fn)

    def _apply(self, fn: Callable[[TensorData], TensorData]) -> TensorFrame:
        r"""Returns a shallow copy in which :obj:`fn` has been applied to
        every entry of :obj:`feat_dict` and to :obj:`y` (if present).
        """
        out = copy.copy(self)
        out.feat_dict = {stype: fn(x) for stype, x in out.feat_dict.items()}
        if out.y is not None:
            y = fn(out.y)
            # Tuple form is evaluated at runtime and, unlike `A | B`, also
            # works on Python versions older than 3.10.
            assert isinstance(y, (Tensor, MultiNestedTensor))
            out.y = y
        return out