torch_frame.data.MultiEmbeddingTensor
- class MultiEmbeddingTensor(num_rows: int, num_cols: int, values: Tensor, offset: Tensor)[source]
Bases:
_MultiTensor
A read-only PyTorch tensor-based data structure that stores data of shape [num_rows, num_cols, *], where the size of the last dimension can differ from column to column. Within a column, the last dimension is the same for every row; in MultiNestedTensor, by contrast, the last dimension can differ across both rows and columns. It supports advanced indexing, including slicing and list indexing along both rows and columns.
- Parameters:
num_rows (int) – Number of rows.
num_cols (int) – Number of columns.
values (torch.Tensor) – The values torch.Tensor of size [num_rows, dim1+dim2+...+dimN], where dimK is the embedding dimension of the K-th column.
offset (torch.Tensor) – The offset torch.Tensor of size [num_cols+1,].
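The relationship between values and offset can be sketched in plain Python. This is an illustration of the layout only, not the library's implementation: it uses nested lists in place of torch.Tensor, and the helper name get_cell is hypothetical. It assumes that column j occupies values[:, offset[j]:offset[j+1]], which is consistent with the example output in this section.

```python
# Illustrative layout sketch with plain lists (not torch_frame code).
values = [
    [0.0, 0.1, 0.2, 0.6, 0.7, 1.0],  # row 0: all columns side by side
    [0.3, 0.4, 0.5, 0.8, 0.9, 1.1],  # row 1
]
offset = [0, 3, 5, 6]  # num_cols + 1 boundaries into the last dimension


def get_cell(values, offset, row, col):
    """Hypothetical helper: return the embedding stored at (row, col)."""
    start, end = offset[col], offset[col + 1]
    return values[row][start:end]


num_cols = len(offset) - 1
print(num_cols)                        # 3
print(get_cell(values, offset, 0, 0))  # [0.0, 0.1, 0.2]
print(get_cell(values, offset, 1, 1))  # [0.8, 0.9]
```

Storing one dense values block plus a short offset vector keeps every column contiguous, so a column lookup is a single slice along the last dimension.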
Example
>>> import torch
>>> from torch_frame.data import MultiEmbeddingTensor
>>> tensor_list = [
...     torch.tensor([[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]]),  # emb col 0
...     torch.tensor([[0.6, 0.7], [0.8, 0.9]]),            # emb col 1
...     torch.tensor([[1.], [1.1]]),                       # emb col 2
... ]
>>> met = MultiEmbeddingTensor.from_tensor_list(tensor_list)
>>> met
MultiEmbeddingTensor(num_rows=2, num_cols=3, device='cpu')
>>> met.values
tensor([[0.0000, 0.1000, 0.2000, 0.6000, 0.7000, 1.0000],
        [0.3000, 0.4000, 0.5000, 0.8000, 0.9000, 1.1000]])
>>> met.offset
tensor([0, 3, 5, 6])
>>> met[0, 0]
tensor([0.0000, 0.1000, 0.2000])
>>> met[1, 1]
tensor([0.8000, 0.9000])
>>> met[0]  # Row integer indexing
MultiEmbeddingTensor(num_rows=1, num_cols=3, device='cpu')
>>> met[:, 0]  # Column integer indexing
MultiEmbeddingTensor(num_rows=2, num_cols=1, device='cpu')
>>> met[:, 0].values  # Embedding of column 0
tensor([[0.0000, 0.1000, 0.2000],
        [0.3000, 0.4000, 0.5000]])
>>> met[:1]  # Row slicing
MultiEmbeddingTensor(num_rows=1, num_cols=3, device='cpu')
>>> met[[0, 1, 0, 0]]  # Row list indexing
MultiEmbeddingTensor(num_rows=4, num_cols=3, device='cpu')
- classmethod from_tensor_list(tensor_list: list[torch.Tensor]) → MultiEmbeddingTensor [source]
Creates a MultiEmbeddingTensor from a list of torch.Tensor.
- Parameters:
tensor_list (List[Tensor]) – A list of tensors, where each tensor has the same number of rows and can have a different number of columns.
- Returns:
A MultiEmbeddingTensor instance.
- Return type:
MultiEmbeddingTensor
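As a rough illustration of what this construction implies for values and offset, the sketch below concatenates the per-column blocks along the last dimension and takes a running sum of their widths. It is a plain-Python approximation with nested lists standing in for tensors; the function name build_values_and_offset is hypothetical and the code is not taken from the library. It assumes at least one row per block.

```python
from itertools import accumulate


def build_values_and_offset(tensor_list):
    """Hypothetical helper mirroring the values/offset layout with lists."""
    num_rows = len(tensor_list[0])
    # Every column block must have the same number of rows.
    if any(len(block) != num_rows for block in tensor_list):
        raise ValueError("all blocks must have the same number of rows")
    # Width (embedding dimension) of each column block.
    widths = [len(block[0]) for block in tensor_list]
    # Cumulative widths give the column boundaries, starting at 0.
    offset = [0] + list(accumulate(widths))
    # Concatenate the blocks row by row along the last dimension.
    values = [
        [x for block in tensor_list for x in block[row]]
        for row in range(num_rows)
    ]
    return values, offset


tensor_list = [
    [[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]],  # emb col 0
    [[0.6, 0.7], [0.8, 0.9]],            # emb col 1
    [[1.0], [1.1]],                      # emb col 2
]
values, offset = build_values_and_offset(tensor_list)
print(offset)     # [0, 3, 5, 6]
print(values[0])  # [0.0, 0.1, 0.2, 0.6, 0.7, 1.0]
```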
- fillna_col(col_index: int, fill_value: int | float | Tensor) → None [source]
Fills the col_index-th column of the MultiTensor with fill_value in-place.
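The in-place behaviour can be pictured with a small plain-Python sketch. It rests on assumptions not stated above: missing entries are represented as NaN, fill_value is a scalar (the signature also allows a Tensor), and nested lists stand in for tensors. The name fillna_col_sketch is hypothetical.

```python
import math


def fillna_col_sketch(values, offset, col_index, fill_value):
    """Hypothetical helper: replace NaNs in one column's slice, in place."""
    start, end = offset[col_index], offset[col_index + 1]
    for row in values:
        for k in range(start, end):
            if math.isnan(row[k]):
                row[k] = fill_value


nan = float("nan")
values = [
    [0.0, nan, 0.2, nan, 0.7],  # column 0 is [0:3], column 1 is [3:5]
    [0.3, 0.4, nan, 0.8, 0.9],
]
offset = [0, 3, 5]

# Only column 0 is touched; the NaN in column 1 of row 0 is left alone.
fillna_col_sketch(values, offset, col_index=0, fill_value=-1.0)
print(values[1])  # [0.3, 0.4, -1.0, 0.8, 0.9]
```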
- static cat(xs: Sequence[MultiEmbeddingTensor], dim: int = 0) → MultiEmbeddingTensor [source]
Concatenates a sequence of MultiEmbeddingTensor along the specified dimension.
- Parameters:
xs (Sequence[MultiEmbeddingTensor]) – A sequence of MultiEmbeddingTensor to be concatenated.
dim (int) – The dimension to concatenate along.
- Returns:
Concatenated multi embedding tensor.
- Return type:
MultiEmbeddingTensor
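One way to think about concatenation in terms of values and offset is sketched below in plain Python. It is not the library's code: each input is modelled as a hypothetical (values, offset) pair of lists, and the sketch assumes that dim=0 stacks rows of inputs with identical offsets, while dim=1 places columns side by side for inputs with the same number of rows.

```python
def cat_sketch(xs, dim=0):
    """Hypothetical helper: xs is a sequence of (values, offset) pairs."""
    if dim == 0:
        # Row-wise: all inputs are assumed to share the same offset.
        offset = xs[0][1]
        if any(off != offset for _, off in xs):
            raise ValueError("row-wise concatenation needs matching offsets")
        values = [list(row) for vals, _ in xs for row in vals]
        return values, list(offset)
    if dim == 1:
        # Column-wise: all inputs are assumed to share the same row count.
        num_rows = len(xs[0][0])
        if any(len(vals) != num_rows for vals, _ in xs):
            raise ValueError("column-wise concatenation needs matching rows")
        values = [
            [x for vals, _ in xs for x in vals[r]]
            for r in range(num_rows)
        ]
        # Shift each input's boundaries by the width accumulated so far.
        offset = [0]
        for _, off in xs:
            shift = offset[-1]
            offset.extend(shift + bound for bound in off[1:])
        return values, offset
    raise ValueError("dim must be 0 or 1")


a = ([[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]], [0, 3])
b = ([[0.6, 0.7], [0.8, 0.9]], [0, 2])

rows_values, rows_offset = cat_sketch([a, a], dim=0)
cols_values, cols_offset = cat_sketch([a, b], dim=1)
print(len(rows_values), rows_offset)  # 4 [0, 3]
print(cols_offset)                    # [0, 3, 5]
print(cols_values[0])                 # [0.0, 0.1, 0.2, 0.6, 0.7]
```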