torch_frame.config.ModelConfig
- class ModelConfig(model: Callable[[torch.Tensor | torch_frame.data.multi_nested_tensor.MultiNestedTensor | torch_frame.data.multi_embedding_tensor.MultiEmbeddingTensor | dict[str, torch_frame.data.multi_nested_tensor.MultiNestedTensor]], Tensor], out_channels: int)
Bases: object

Learnable model that maps a single-column TensorData object into row embeddings.

- Parameters:
  - model (callable) – A callable model that takes a TensorData object of shape [batch_size, 1, *] as input and outputs embeddings of shape [batch_size, 1, out_channels].
  - out_channels (int) – Model output channels.
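
The shape contract described above can be illustrated with a minimal sketch. It assumes the single-column input is a plain torch.Tensor (one of the input types listed in the signature). The ColumnEncoder class, the input width of 16, and the batch size of 4 are hypothetical choices made only for this illustration and are not part of torch_frame.

```python
import torch
from torch import nn


class ColumnEncoder(nn.Module):
    """Hypothetical per-column model; the name is not part of torch_frame."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        # nn.Linear acts on the last dimension only, so the leading
        # [batch_size, 1] dimensions pass through unchanged.
        self.lin = nn.Linear(in_channels, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, 1, in_channels] -> [batch_size, 1, out_channels]
        return self.lin(x)


out_channels = 8
encoder = ColumnEncoder(in_channels=16, out_channels=out_channels)

# Assumed example input: 4 rows, 1 column, 16 features per cell.
x = torch.randn(4, 1, 16)
emb = encoder(x)
print(emb.shape)

# Wrap the callable together with its output width. The import is guarded
# so the shape check above still runs when torch_frame is not installed.
try:
    from torch_frame.config import ModelConfig

    config = ModelConfig(model=encoder, out_channels=out_channels)
except ImportError:
    config = None
```

The out_channels value passed to ModelConfig should match the size of the model's last output dimension, since the documented output shape is [batch_size, 1, out_channels].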