torch_frame.config.ModelConfig

class ModelConfig(model: Callable[[torch.Tensor | torch_frame.data.multi_nested_tensor.MultiNestedTensor | torch_frame.data.multi_embedding_tensor.MultiEmbeddingTensor | dict[str, torch_frame.data.multi_nested_tensor.MultiNestedTensor]], Tensor], out_channels: int)

Bases: object

Configuration for a learnable model that maps a single-column TensorData object into row embeddings.

Parameters:
  • model (callable) – A callable model that takes a TensorData object of shape [batch_size, 1, *] as input and outputs embeddings of shape [batch_size, 1, out_channels].

  • out_channels (int) – Number of output channels produced by model, i.e. the size of the last dimension of its output embeddings.