Source code for torch_frame.data.multi_embedding_tensor

from __future__ import annotations

from collections.abc import Sequence

import torch
from torch import Tensor

from torch_frame.data.multi_tensor import _batched_arange, _MultiTensor


class MultiEmbeddingTensor(_MultiTensor):
    r"""A read-only PyTorch tensor-based data structure that stores
    :obj:`[num_rows, num_cols, *]`, where the size of the last dimension can
    differ from column to column. Note that the last dimension is the same
    within each column across rows, whereas in :class:`MultiNestedTensor` the
    last dimension can differ across both rows and columns.
    It supports various advanced indexing, including slicing and list indexing
    along both rows and columns.

    Args:
        num_rows (int): Number of rows.
        num_cols (int): Number of columns.
        values (torch.Tensor): The values :class:`torch.Tensor` of size
            :obj:`[num_rows, dim1+dim2+...+dimN]`.
        offset (torch.Tensor): The offset :class:`torch.Tensor` of size
            :obj:`[num_cols+1,]`.

    Example:
        >>> tensor_list = [
        ...    torch.tensor([[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]]),  # emb col 0
        ...    torch.tensor([[0.6, 0.7], [0.8, 0.9]]),            # emb col 1
        ...    torch.tensor([[1.], [1.1]]),                       # emb col 2
        ... ]
        >>> met = MultiEmbeddingTensor.from_tensor_list(tensor_list)
        >>> met
        MultiEmbeddingTensor(num_rows=2, num_cols=3, device='cpu')
        >>> met.values
        tensor([[0.0000, 0.1000, 0.2000, 0.6000, 0.7000, 1.0000],
                [0.3000, 0.4000, 0.5000, 0.8000, 0.9000, 1.1000]])
        >>> met.offset
        tensor([0, 3, 5, 6])
        >>> met[0, 0]
        tensor([0.0000, 0.1000, 0.2000])
        >>> met[1, 1]
        tensor([0.8000, 0.9000])
        >>> met[0] # Row integer indexing
        MultiEmbeddingTensor(num_rows=1, num_cols=3, device='cpu')
        >>> met[:, 0] # Column integer indexing
        MultiEmbeddingTensor(num_rows=2, num_cols=1, device='cpu')
        >>> met[:, 0].values # Embedding of column 0
        tensor([[0.0000, 0.1000, 0.2000],
                [0.3000, 0.4000, 0.5000]])
        >>> met[:1] # Row slicing
        MultiEmbeddingTensor(num_rows=1, num_cols=3, device='cpu')
        >>> met[[0, 1, 0, 0]] # Row list indexing
        MultiEmbeddingTensor(num_rows=4, num_cols=3, device='cpu')
    """
    def validate(self) -> None:
        r"""Asserts the structural invariants of :obj:`offset` and
        :obj:`values`. The checks are skipped while :mod:`torch.compiler` is
        compiling, so tensor-valued asserts do not break tracing.
        """
        if torch.compiler.is_compiling():
            return
        assert self.offset[0] == 0
        assert len(self.offset) == self.num_cols + 1
        assert self.offset.ndim == 1
        # An empty ``values`` (e.g. from ``_empty``) may be 1-D.
        assert self.values.ndim == 2 or self.values.numel() == 0

    @classmethod
    def from_tensor_list(
        cls,
        tensor_list: list[Tensor],
    ) -> MultiEmbeddingTensor:
        r"""Creates a :class:`MultiEmbeddingTensor` from a list of
        :class:`torch.Tensor`.

        Args:
            tensor_list (List[Tensor]): A list of tensors, where each tensor
                has the same number of rows and can have a different number of
                columns.

        Returns:
            MultiEmbeddingTensor: A :class:`MultiEmbeddingTensor` instance.

        Raises:
            AssertionError: If :obj:`tensor_list` is not a non-empty list of
                2D tensors sharing the same number of rows and device.
        """
        assert isinstance(tensor_list, list) and len(tensor_list) > 0
        num_rows = tensor_list[0].size(0)
        device = tensor_list[0].device
        for tensor in tensor_list:
            msg = "tensor_list must be a list of tensors."
            assert isinstance(tensor, torch.Tensor), msg
            msg = "tensor_list must be a list of 2D tensors."
            assert tensor.dim() == 2, msg
            msg = "num_rows must be the same across a list of input tensors."
            assert tensor.size(0) == num_rows, msg
            msg = "device must be the same across a list of input tensors."
            assert tensor.device == device, msg

        # offset[j] is the first value column belonging to embedding column
        # j; the final entry equals the total concatenated embedding width.
        offset_list = []
        accum_idx = 0
        offset_list.append(accum_idx)
        for tensor in tensor_list:
            accum_idx += tensor.size(1)
            offset_list.append(accum_idx)

        num_cols = len(tensor_list)
        values = torch.cat(tensor_list, dim=1)
        offset = torch.tensor(offset_list, device=values.device)
        return cls(num_rows, num_cols, values, offset)

    def _get_value(self, i: int, j: int) -> Tensor:
        r"""Get :obj:`(i, j)`-th :class:`Tensor` object.

        Args:
            i (int): The row integer index.
            j (int): The column integer index.
        """
        i = self._normalize_index(i, dim=0)
        j = self._normalize_index(j, dim=1)
        return self.values[i, self.offset[j]:self.offset[j + 1]]

    def _row_narrow(self, start: int, length: int) -> MultiEmbeddingTensor:
        r"""Helper function called by :meth:`MultiEmbeddingTensor.narrow`."""
        # Column layout is unchanged, so the offset tensor is shared.
        return MultiEmbeddingTensor(
            num_rows=length,
            num_cols=self.num_cols,
            values=self.values[start:start + length],
            offset=self.offset,
        )

    def _col_narrow(self, start: int, length: int) -> MultiEmbeddingTensor:
        r"""Helper function called by :meth:`MultiEmbeddingTensor.narrow`."""
        # Re-base the selected offsets so the new offset starts at 0.
        offset = self.offset[start:start + length + 1] - self.offset[start]
        col_offset_start = self.offset[start]
        col_offset_end = self.offset[start + length]
        return MultiEmbeddingTensor(
            num_rows=self.num_rows,
            num_cols=length,
            values=self.values[:, col_offset_start:col_offset_end],
            offset=offset,
        )

    def _row_index_select(self, index: Tensor) -> MultiEmbeddingTensor:
        r"""Helper function called by
        :meth:`MultiEmbeddingTensor.index_select`.
        """
        return MultiEmbeddingTensor(
            num_rows=index.size(0),
            num_cols=self.num_cols,
            values=self.values[index],
            offset=self.offset,
        )

    def _col_index_select(self, index: Tensor) -> MultiEmbeddingTensor:
        r"""Helper function called by
        :meth:`MultiEmbeddingTensor.index_select`.
        """
        if index.numel() == 0:
            return self._empty(dim=1)
        offset = torch.zeros(
            index.size(0) + 1,
            dtype=torch.long,
            device=self.device,
        )
        # Width of every original column, then of each selected column; the
        # new offset is the running sum of the selected widths.
        col_dims = self.offset[1:] - self.offset[:-1]
        new_col_dims = col_dims[index]
        torch.cumsum(new_col_dims, dim=0, out=offset[1:])
        # NOTE(review): _batched_arange is assumed to return, per output
        # element, its segment id (batch) and position inside the segment
        # (arange); together they expand each selected column j into the
        # value positions offset[j] .. offset[j] + dim_j - 1.
        batch, arange = _batched_arange(new_col_dims)
        value_index = self.offset[index][batch] + arange
        return MultiEmbeddingTensor(
            num_rows=self.num_rows,
            num_cols=index.size(0),
            values=self.values[:, value_index],
            offset=offset,
        )

    def _single_index_select(
        self,
        index: int,
        dim: int,
    ) -> MultiEmbeddingTensor:
        r"""Helper function called by
        :meth:`MultiEmbeddingTensor.index_select`.
        """
        index = self._normalize_index(index, dim=dim)
        if dim == 0:
            return MultiEmbeddingTensor(
                num_rows=1,
                num_cols=self.num_cols,
                values=self.values[index].view(1, -1),
                offset=self.offset,
            )
        elif dim == 1:
            value_index = slice(self.offset[index], self.offset[index + 1])
            values = self.values[:, value_index]
            # Produces [0, dim_index]: a single-column offset tensor.
            offset = self.offset[[0, index + 1]] - self.offset[[0, index]]
            return MultiEmbeddingTensor(
                num_rows=self.num_rows,
                num_cols=1,
                values=values,
                offset=offset,
            )
        raise AssertionError("Should not reach here.")

    def fillna_col(
        self,
        col_index: int,
        fill_value: int | float | Tensor,
    ) -> None:
        r"""Fills missing entries of one column in place.

        For floating-point :obj:`values`, ``NaN`` entries are replaced; for
        other dtypes, entries equal to ``-1`` are replaced.

        Args:
            col_index (int): The column index. It is normalized the same way
                as in :meth:`_get_value`.
            fill_value (int | float | Tensor): The value to write.
        """
        # Without normalization a negative index builds the slice
        # offset[-1]:offset[0] (offset[0] == 0), which is empty, so the call
        # would silently fill nothing.
        col_index = self._normalize_index(col_index, dim=1)
        values_index = slice(self.offset[col_index],
                             self.offset[col_index + 1])
        # Basic slicing returns a view, so the masked assignment below writes
        # through to self.values.
        values_col = self.values[:, values_index]
        if self.values.is_floating_point():
            values_col[torch.isnan(values_col)] = fill_value
        else:
            values_col[values_col == -1] = fill_value

    def _empty(self, dim: int) -> MultiEmbeddingTensor:
        """Creates an empty :class:`MultiEmbeddingTensor`.

        Args:
            dim (int): The dimension to empty.

        Returns:
            MultiEmbeddingTensor: An empty :class:`MultiEmbeddingTensor`.
                Note that if :obj:`dim=0`, it will return with the original
                offset tensor.
        """
        return MultiEmbeddingTensor(
            num_rows=0 if dim == 0 else self.num_rows,
            num_cols=0 if dim == 1 else self.num_cols,
            values=torch.tensor([], device=self.device, dtype=self.dtype),
            offset=torch.tensor([0], device=self.device, dtype=torch.long)
            if dim == 1 else self.offset,
        )

    @staticmethod
    def cat(
        xs: Sequence[MultiEmbeddingTensor],
        dim: int = 0,
    ) -> MultiEmbeddingTensor:
        """Concatenates a sequence of :class:`MultiEmbeddingTensor` along the
        specified dimension.

        Args:
            xs (Sequence[MultiEmbeddingTensor]): A sequence of
                :class:`MultiEmbeddingTensor` to be concatenated.
            dim (int): The dimension to concatenate along.

        Returns:
            MultiEmbeddingTensor: Concatenated multi embedding tensor. If
                :obj:`xs` has a single element, that element is returned.

        Raises:
            RuntimeError: If :obj:`xs` is empty, or if the inputs disagree on
                ``num_cols``/``offset`` (:obj:`dim=0`) or ``num_rows``
                (:obj:`dim=1`).
        """
        if len(xs) == 0:
            raise RuntimeError("Cannot concatenate a sequence of length 0.")

        for x in xs:
            msg = "`xs` must be a list of MultiEmbeddingTensor."
            assert isinstance(x, MultiEmbeddingTensor), msg
            msg = ("device must be the same across a sequence of"
                   " MultiEmbeddingTensor.")
            assert x.device == xs[0].device, msg

        dim = MultiEmbeddingTensor._normalize_dim(dim)

        if len(xs) == 1:
            return xs[0]

        if dim == 0:
            num_rows = sum(x.num_rows for x in xs)
            num_cols = xs[0].num_cols
            # NOTE: offset shares the same data with the input's offset,
            # which is inconsistent with when dim=1
            offset = xs[0].offset
            for x in xs[1:]:
                if x.num_cols != num_cols:
                    raise RuntimeError(
                        "num_cols must be the same across a list of input "
                        "multi embedding tensors.")
                # Equal num_cols and equal total width is not enough: a
                # different per-column split would make xs[0].offset mislabel
                # the rows coming from x. Skip the comparison (and its device
                # sync) when the very same offset tensor is shared.
                if x.offset is not offset and not torch.equal(
                        x.offset, offset):
                    raise RuntimeError(
                        "offset must be the same across a list of input "
                        "multi embedding tensors.")
            values = torch.cat([x.values for x in xs], dim=0)
            return MultiEmbeddingTensor(num_rows, num_cols, values, offset)

        elif dim == 1:
            num_rows = xs[0].num_rows
            for x in xs[1:]:
                if x.num_rows != num_rows:
                    raise RuntimeError(
                        "num_rows must be the same across a list of input "
                        "multi embedding tensors.")

            num_cols = sum(x.num_cols for x in xs)
            values = torch.cat([x.values for x in xs], dim=1)
            # Shift each input's offsets by the total width of all preceding
            # inputs and concatenate them on-device. Building the result with
            # torch.tensor(list) would place it on the CPU regardless of the
            # inputs' device.
            offset_parts = [xs[0].offset]
            shift = xs[0].offset[-1]
            for x in xs[1:]:
                offset_parts.append(x.offset[1:] + shift)
                shift = shift + x.offset[-1]
            # NOTE: offset is a data copy of the input's offset,
            # which is inconsistent with when dim=0
            offset = torch.cat(offset_parts)
            return MultiEmbeddingTensor(num_rows, num_cols, values, offset)

        raise AssertionError("Should not reach here.")