Source code for torch_frame.utils.io

from __future__ import annotations

import pickle
import re
import warnings
from typing import Any

import torch
from torch import Tensor

import torch_frame
from torch_frame.data import (
    MultiEmbeddingTensor,
    MultiNestedTensor,
    TensorFrame,
)
from torch_frame.data.multi_tensor import _MultiTensor
from torch_frame.data.stats import StatType
from torch_frame.typing import TensorData


def serialize_feat_dict(
    feat_dict: dict[torch_frame.stype, TensorData]
) -> dict[torch_frame.stype, Any]:
    feat_serialized_dict = {}
    for stype, feat in feat_dict.items():
        feat_serialized: Any
        if stype.use_multi_tensor:
            assert isinstance(feat, _MultiTensor)
            feat_serialized = feat.to_dict()
        elif stype.use_dict_multi_nested_tensor:
            feat_serialized = {}
            assert isinstance(feat, dict)
            for name, f in feat.items():
                assert isinstance(f, MultiNestedTensor)
                feat_serialized[name] = f.to_dict()
        else:
            assert isinstance(feat, Tensor)
            feat_serialized = feat
        feat_serialized_dict[stype] = feat_serialized
    return feat_serialized_dict


def deserialize_feat_dict(
    feat_serialized_dict: dict[torch_frame.stype, Any]
) -> dict[torch_frame.stype, TensorData]:
    feat_dict = {}
    for stype, feat_serialized in feat_serialized_dict.items():
        feat: TensorData
        if stype.use_multi_nested_tensor:
            feat = MultiNestedTensor(**feat_serialized)
        elif stype.use_multi_embedding_tensor:
            feat = MultiEmbeddingTensor(**feat_serialized)
        elif stype.use_dict_multi_nested_tensor:
            feat = {}
            for name, f_serialized in feat_serialized.items():
                feat[name] = MultiNestedTensor(**f_serialized)
        else:
            assert isinstance(feat_serialized, Tensor)
            feat = feat_serialized
        feat_dict[stype] = feat
    return feat_dict


[docs]def save(tensor_frame: TensorFrame, col_stats: dict[str, dict[StatType, Any]] | None, path: str): r"""Save a :class:`TensorFrame` object and optional :obj:`col_stats` to specified path. Args: tensor_frame(TensorFrame): The :class:`TensorFrame` object to be saved. col_stats (dict, optional): The :obj:`col_stats` to be saved. path (str): Path to save the :class:`TensorFrame` object and :obj:`col_stats`. """ tf_dict = { 'y': tensor_frame.y, 'col_names_dict': tensor_frame.col_names_dict, 'feat_serialized_dict': serialize_feat_dict(tensor_frame.feat_dict), } torch.save((tf_dict, col_stats), path)
[docs]def load( path: str, device: torch.device | None = None, ) -> tuple[TensorFrame, dict[str, dict[StatType, Any]] | None]: r"""Load saved :class:`TensorFrame` object and optional :obj:`col_stats` from a specified path. Args: path (str): Path to load the :class:`TensorFrame` object and :obj:`col_stats`. device (torch.device, optional): Device to load the :class:`TensorFrame` object. (default: :obj:`None`) Returns: tuple: A tuple of loaded :class:`TensorFrame` object and optional :obj:`col_stats`. """ if torch_frame.typing.WITH_PT24: try: tf_dict, col_stats = torch.load(path, weights_only=True) except pickle.UnpicklingError as e: error_msg = str(e) if "add_safe_globals" in error_msg: warn_msg = ("Weights only load failed. Please file an issue " "to make `torch.load(weights_only=True)` " "compatible in your case.") match = re.search(r'add_safe_globals\(.*?\)', error_msg) if match is not None: warnings.warn( f"{warn_msg} Please use " f"`torch.serialization.{match.group()}` to " f"allowlist this global.", stacklevel=2) else: warnings.warn(warn_msg, stacklevel=2) tf_dict, col_stats = torch.load(path, weights_only=False) else: raise e else: tf_dict, col_stats = torch.load(path, weights_only=False) tf_dict['feat_dict'] = deserialize_feat_dict( tf_dict.pop('feat_serialized_dict')) tensor_frame = TensorFrame(**tf_dict).to(device) return tensor_frame, col_stats