Source code for torch_frame.config.model

from collections.abc import Callable
from dataclasses import dataclass

from torch import Tensor

from torch_frame.typing import TensorData


[docs]@dataclass class ModelConfig: r"""Learnable model that maps a single-column :class:`TensorData` object into row embeddings. Args: model (callable): A callable model that takes a :obj:`TensorData` object of shape :obj:`[batch_size, 1, *]` as input and outputs embeddings of shape :obj:`[batch_size, 1, out_channels]`. out_channels (int): Model output channels. """ model: Callable[[TensorData], Tensor] out_channels: int