torch_frame.data.MultiEmbeddingTensor
- class MultiEmbeddingTensor(num_rows: int, num_cols: int, values: Tensor, offset: Tensor)
Bases: _MultiTensor

A read-only PyTorch tensor-based data structure that stores data of shape [num_rows, num_cols, *], where the size of the last dimension can differ from column to column. Note that the last dimension is the same within each column across rows, whereas in MultiNestedTensor the last dimension can vary across both rows and columns. It supports various advanced indexing operations, including slicing and list indexing along both rows and columns.
- Parameters:
num_rows (int) – Number of rows.
num_cols (int) – Number of columns.
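values (torch.Tensor) – The values tensor of size [num_rows, dim1+dim2+...+dimN], where dimj is the embedding size of the j-th column.
offset (torch.Tensor) – The offset tensor of size [num_cols+1,]. The embedding of column j occupies values[:, offset[j]:offset[j+1]].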
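As a rough illustration of the values/offset layout, the following sketch uses plain PyTorch (no torch_frame import) to read one (row, column) entry out of the packed tensors. The helper name get_entry is hypothetical and not part of the torch_frame API; it only assumes that column j occupies values[:, offset[j]:offset[j+1]], which matches the offset shown in the example.

```python
import torch


# Hypothetical helper (not part of torch_frame): reads the embedding stored
# for one (row, column) cell of a packed values/offset pair.
def get_entry(values: torch.Tensor, offset: torch.Tensor, row: int, col: int) -> torch.Tensor:
    # Column ``col`` occupies values[:, offset[col]:offset[col + 1]].
    start = int(offset[col])
    end = int(offset[col + 1])
    return values[row, start:end]


# Same data as the example: embedding columns of sizes 3, 2 and 1.
values = torch.tensor([
    [0.0, 0.1, 0.2, 0.6, 0.7, 1.0],
    [0.3, 0.4, 0.5, 0.8, 0.9, 1.1],
])
offset = torch.tensor([0, 3, 5, 6])

num_rows = values.size(0)
num_cols = offset.numel() - 1  # offset has num_cols + 1 entries
print(num_rows, num_cols)  # 2 3
print(get_entry(values, offset, 1, 1))  # tensor([0.8000, 0.9000])
```

Storing all columns in one contiguous 2-D tensor keeps row operations cheap, at the cost of needing offset to locate each column.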
Example
>>> import torch
>>> from torch_frame.data import MultiEmbeddingTensor
>>> tensor_list = [
...     torch.tensor([[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]]),  # emb col 0
...     torch.tensor([[0.6, 0.7], [0.8, 0.9]]),  # emb col 1
...     torch.tensor([[1.], [1.1]]),  # emb col 2
... ]
>>> met = MultiEmbeddingTensor.from_tensor_list(tensor_list)
>>> met
MultiEmbeddingTensor(num_rows=2, num_cols=3, device='cpu')
>>> met.values
tensor([[0.0000, 0.1000, 0.2000, 0.6000, 0.7000, 1.0000],
        [0.3000, 0.4000, 0.5000, 0.8000, 0.9000, 1.1000]])
>>> met.offset
tensor([0, 3, 5, 6])
>>> met[0, 0]
tensor([0.0000, 0.1000, 0.2000])
>>> met[1, 1]
tensor([0.8000, 0.9000])
>>> met[0]  # Row integer indexing
MultiEmbeddingTensor(num_rows=1, num_cols=3, device='cpu')
>>> met[:, 0]  # Column integer indexing
MultiEmbeddingTensor(num_rows=2, num_cols=1, device='cpu')
>>> met[:, 0].values  # Embedding of column 0
tensor([[0.0000, 0.1000, 0.2000],
        [0.3000, 0.4000, 0.5000]])
>>> met[:1]  # Row slicing
MultiEmbeddingTensor(num_rows=1, num_cols=3, device='cpu')
>>> met[[0, 1, 0, 0]]  # Row list indexing
MultiEmbeddingTensor(num_rows=4, num_cols=3, device='cpu')
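The row and column indexing shown in the example can be approximated on the raw tensors. This sketch is an assumption about the semantics, not torch_frame's implementation: row selection indexes values and keeps offset unchanged, while column selection gathers the per-column slices and recomputes offset. The names select_rows and select_cols are hypothetical.

```python
import torch


# Hypothetical helpers (not torch_frame's implementation) that mimic row and
# column selection on a packed values/offset pair.
def select_rows(values: torch.Tensor, offset: torch.Tensor, rows: list[int]):
    # Picking rows never changes the column layout, so offset is reused as-is.
    return values[rows], offset


def select_cols(values: torch.Tensor, offset: torch.Tensor, cols: list[int]):
    # Gather the slice of each requested column, in the requested order.
    pieces = [values[:, int(offset[c]):int(offset[c + 1])] for c in cols]
    # Rebuild offset from the widths of the selected columns.
    widths = torch.tensor([p.size(1) for p in pieces], dtype=torch.long)
    new_offset = torch.cat([torch.zeros(1, dtype=torch.long), torch.cumsum(widths, dim=0)])
    return torch.cat(pieces, dim=1), new_offset


values = torch.tensor([
    [0.0, 0.1, 0.2, 0.6, 0.7, 1.0],
    [0.3, 0.4, 0.5, 0.8, 0.9, 1.1],
])
offset = torch.tensor([0, 3, 5, 6])

sub_values, sub_offset = select_cols(values, offset, [2, 0])
print(sub_offset)  # tensor([0, 1, 4])
rep_values, rep_offset = select_rows(values, offset, [0, 1, 0, 0])
print(rep_values.shape)  # torch.Size([4, 6])
```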
- classmethod from_tensor_list(tensor_list: list[torch.Tensor]) → MultiEmbeddingTensor
Creates a MultiEmbeddingTensor from a list of torch.Tensor objects.
- Parameters:
tensor_list (List[Tensor]) – A list of tensors, one per embedding column. All tensors must have the same number of rows but may differ in their number of columns (embedding size).
- Returns:
A MultiEmbeddingTensor instance.
- Return type:
MultiEmbeddingTensor
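A minimal pure-PyTorch sketch of what from_tensor_list has to compute, assuming each input is a 2-D tensor of shape [num_rows, dim_i]: the tensors are concatenated along the last dimension to form values, and offset is the cumulative sum of the column widths with a leading zero. pack_tensor_list is a hypothetical helper, not the library's code.

```python
import torch


# Hypothetical helper (not the torch_frame code): computes the values and
# offset tensors that from_tensor_list is documented to produce.
def pack_tensor_list(tensor_list: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    num_rows = tensor_list[0].size(0)
    for t in tensor_list:
        # Assumption: every column embedding is 2-D, [num_rows, dim_i].
        if t.dim() != 2 or t.size(0) != num_rows:
            raise ValueError("all tensors must be 2-D with the same number of rows")
    # Place the column embeddings side by side: [num_rows, dim1 + ... + dimN].
    values = torch.cat(tensor_list, dim=1)
    # offset[j] is where column j starts; offset[-1] is the total width.
    widths = torch.tensor([t.size(1) for t in tensor_list], dtype=torch.long)
    offset = torch.cat([torch.zeros(1, dtype=torch.long), torch.cumsum(widths, dim=0)])
    return values, offset


tensor_list = [
    torch.tensor([[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]]),  # emb col 0
    torch.tensor([[0.6, 0.7], [0.8, 0.9]]),  # emb col 1
    torch.tensor([[1.0], [1.1]]),  # emb col 2
]
values, offset = pack_tensor_list(tensor_list)
print(offset)  # tensor([0, 3, 5, 6])
```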
- fillna_col(col_index: int, fill_value: int | float | torch.Tensor) → None
Fills the missing (NaN) entries of the col_index-th column with fill_value, in-place.
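The following sketch shows one way an in-place fill of a single column could work on the packed layout. It assumes missing entries are NaN in a floating-point values tensor and that fill_value is a scalar or a tensor broadcastable to the column's width; fillna_col_sketch is a hypothetical name, and this is not the torch_frame implementation.

```python
import torch


# Hypothetical helper (not torch_frame's fillna_col): replaces NaN entries of
# one column of a packed values/offset pair, modifying ``values`` in place.
# Assumption: missing values are encoded as NaN in a floating-point tensor.
def fillna_col_sketch(values: torch.Tensor, offset: torch.Tensor, col_index: int, fill_value) -> None:
    start, end = int(offset[col_index]), int(offset[col_index + 1])
    col = values[:, start:end]  # basic slicing returns a view into ``values``
    fill = torch.as_tensor(fill_value, dtype=values.dtype, device=values.device)
    # ``fill`` may be a scalar or broadcastable to [num_rows, end - start].
    col.copy_(torch.where(torch.isnan(col), fill, col))


nan = float("nan")
values = torch.tensor([
    [0.0, nan, 0.2, 0.6],
    [nan, 0.4, 0.5, nan],
])
offset = torch.tensor([0, 3, 4])  # column 0 has width 3, column 1 has width 1
fillna_col_sketch(values, offset, col_index=0, fill_value=-1.0)
print(values)  # NaNs in column 0 become -1.0; column 1 keeps its NaN
```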
- static cat(xs: Sequence[MultiEmbeddingTensor], dim: int = 0) → MultiEmbeddingTensor
Concatenates a sequence of MultiEmbeddingTensor objects along the specified dimension.
- Parameters:
xs (Sequence[MultiEmbeddingTensor]) – A sequence of MultiEmbeddingTensor objects to be concatenated.
dim (int) – The dimension to concatenate along.
- Returns:
The concatenated MultiEmbeddingTensor.
- Return type:
MultiEmbeddingTensor
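To illustrate what concatenation means for the packed layout, here is a sketch on raw (values, offset) pairs. It assumes dim=0 stacks rows (so all inputs must share the same offset) and dim=1 appends columns (so each later offset is shifted by the widths of the inputs before it). cat_sketch is a hypothetical helper, not torch_frame's MultiEmbeddingTensor.cat.

```python
import torch


# Hypothetical helper (not torch_frame's MultiEmbeddingTensor.cat): concatenates
# packed (values, offset) pairs along rows (dim=0) or columns (dim=1).
def cat_sketch(pairs: list[tuple[torch.Tensor, torch.Tensor]], dim: int = 0):
    if not pairs:
        raise ValueError("need at least one (values, offset) pair")
    if dim == 0:
        # Row-wise: every input must share the same column layout.
        offset = pairs[0][1]
        for _, other in pairs[1:]:
            if not torch.equal(other, offset):
                raise ValueError("row-wise concatenation needs identical offsets")
        return torch.cat([v for v, _ in pairs], dim=0), offset
    if dim == 1:
        # Column-wise: embeddings go side by side (row counts must match),
        # and each later offset is shifted by the total width before it.
        values = torch.cat([v for v, _ in pairs], dim=1)
        offsets = [pairs[0][1]]
        shift = pairs[0][1][-1]
        for _, off in pairs[1:]:
            offsets.append(off[1:] + shift)
            shift = shift + off[-1]
        return values, torch.cat(offsets)
    raise ValueError("dim must be 0 (rows) or 1 (columns)")


a = (torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), torch.tensor([0, 2, 3]))
b = (torch.tensor([[7.0], [8.0]]), torch.tensor([0, 1]))
cols_values, cols_offset = cat_sketch([a, b], dim=1)
print(cols_offset)  # tensor([0, 2, 3, 4])
rows_values, rows_offset = cat_sketch([a, a], dim=0)
print(rows_values.shape)  # torch.Size([4, 3])
```

Row-wise concatenation is a single torch.cat on values, while column-wise concatenation also has to rebase the offsets, which is why the two cases are handled separately here.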