# Source code for torch_frame.nn.models.mlp

from __future__ import annotations

from typing import Any

import torch
from torch import Tensor
from torch.nn import (
    BatchNorm1d,
    Dropout,
    LayerNorm,
    Linear,
    Module,
    ReLU,
    Sequential,
)

import torch_frame
from torch_frame import TensorFrame, stype
from torch_frame.data.stats import StatType
from torch_frame.nn.encoder.stype_encoder import (
    EmbeddingEncoder,
    LinearEncoder,
    StypeEncoder,
)
from torch_frame.nn.encoder.stypewise_encoder import StypeWiseFeatureEncoder


class MLP(Module):
    r"""The light-weight MLP model that mean-pools column embeddings and
    applies MLP over it.

    Args:
        channels (int): The number of channels in the backbone layers.
        out_channels (int): The number of output channels in the decoder.
        num_layers (int): The number of layers in the backbone. Must be at
            least :obj:`1`; the final :class:`Linear` projection to
            :obj:`out_channels` counts as one of these layers.
        col_stats (dict[str, dict[:class:`torch_frame.data.stats.StatType`, Any]]):
            A dictionary that maps column name into stats.
            Available as :obj:`dataset.col_stats`.
        col_names_dict (dict[:class:`torch_frame.stype`, list[str]]): A
            dictionary that maps stype to a list of column names. The column
            names are sorted based on the ordering that appear in
            :obj:`tensor_frame.feat_dict`. Available as
            :obj:`tensor_frame.col_names_dict`.
        stype_encoder_dict
            (dict[:class:`torch_frame.stype`,
            :class:`torch_frame.nn.encoder.StypeEncoder`], optional):
            A dictionary mapping stypes into their stype encoders.
            (default: :obj:`None`, will call :obj:`EmbeddingEncoder()`
            for categorical feature and :obj:`LinearEncoder()`
            for numerical feature)
        normalization (str, optional): The type of normalization to use.
            :obj:`"batch_norm"`, :obj:`"layer_norm"`, or :obj:`None`.
            (default: :obj:`"layer_norm"`)
        dropout_prob (float): The dropout probability (default: :obj:`0.2`).

    Raises:
        ValueError: If :obj:`num_layers` is less than :obj:`1`, or if
            :obj:`normalization` is not :obj:`"layer_norm"`,
            :obj:`"batch_norm"`, or :obj:`None`.
    """
    def __init__(
        self,
        channels: int,
        out_channels: int,
        num_layers: int,
        col_stats: dict[str, dict[StatType, Any]],
        col_names_dict: dict[torch_frame.stype, list[str]],
        stype_encoder_dict: dict[torch_frame.stype, StypeEncoder]
        | None = None,
        normalization: str | None = "layer_norm",
        dropout_prob: float = 0.2,
    ) -> None:
        super().__init__()

        # Validate hyperparameters before building any sub-modules so that a
        # typo such as "layernorm" fails loudly instead of silently producing
        # a backbone with no normalization layers.
        if num_layers < 1:
            raise ValueError(
                f"'num_layers' must be at least 1 (got {num_layers}).")
        if normalization not in (None, "layer_norm", "batch_norm"):
            raise ValueError(
                "'normalization' must be 'layer_norm', 'batch_norm' or None "
                f"(got {normalization!r}).")

        if stype_encoder_dict is None:
            stype_encoder_dict = {
                stype.categorical: EmbeddingEncoder(),
                stype.numerical: LinearEncoder(),
            }

        # Encodes every column into a `channels`-dimensional embedding.
        self.encoder = StypeWiseFeatureEncoder(
            out_channels=channels,
            col_stats=col_stats,
            col_names_dict=col_names_dict,
            stype_encoder_dict=stype_encoder_dict,
        )

        # Backbone: (num_layers - 1) hidden blocks of
        # Linear -> [Norm] -> ReLU -> Dropout, followed by a final Linear
        # projection to `out_channels`.
        self.mlp = Sequential()
        for _ in range(num_layers - 1):
            self.mlp.append(Linear(channels, channels))
            if normalization == "layer_norm":
                self.mlp.append(LayerNorm(channels))
            elif normalization == "batch_norm":
                self.mlp.append(BatchNorm1d(channels))
            self.mlp.append(ReLU())
            self.mlp.append(Dropout(p=dropout_prob))
        self.mlp.append(Linear(channels, out_channels))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        r"""Re-initialize the encoder and every learnable backbone layer."""
        self.encoder.reset_parameters()
        # ReLU and Dropout have no parameters (and no `reset_parameters`),
        # so only the modules that define it are reset.
        for module in self.mlp:
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()

    def forward(self, tf: TensorFrame) -> Tensor:
        r"""Transforming :class:`TensorFrame` object into output prediction.

        Args:
            tf (TensorFrame): Input :class:`TensorFrame` object.

        Returns:
            torch.Tensor: Output of shape [batch_size, out_channels].
        """
        # The encoder returns a tuple; only the embeddings are used here.
        # Assumed shape: [batch_size, num_cols, channels] (TODO confirm).
        x, _ = self.encoder(tf)
        # Mean-pool over the column axis to get one vector per row.
        x = torch.mean(x, dim=1)
        out = self.mlp(x)
        return out