Source code for torch_frame.utils.memory

from typing import Any

from torch import Tensor

from torch_frame.data.multi_embedding_tensor import MultiEmbeddingTensor
from torch_frame.data.multi_nested_tensor import MultiNestedTensor


[docs]def num_bytes(data: Any) -> int: r"""Returns the number of bytes the tensor data consumes. Args: data: The tensor data. """ if isinstance(data, Tensor): return data.element_size() * data.numel() if isinstance(data, MultiNestedTensor | MultiEmbeddingTensor): return num_bytes(data.values) + num_bytes(data.offset) if isinstance(data, list): return sum([num_bytes(value) for value in data]) if isinstance(data, dict): return sum([num_bytes(value) for value in data.values()]) raise NotImplementedError(f"'num_bytes' not implemented for " f"'{type(data)}'")