Handling Text Columns

PyTorch Frame handles text columns by utilizing text embedding models, such as pre-trained language models. We support two major options for utilizing text embedding models:

  1. Pre-encode texts into embeddings at the materialization stage, so that the text model parameters are frozen during the training stage.

  2. Generate text embeddings during the training stage and fine-tune the text model parameters end-to-end.

These options involve a trade-off: option (1) allows faster training, while option (2) allows more accurate prediction at the cost of more expensive training due to fine-tuning the text models. In PyTorch Frame, we can choose an option for each text column simply by specifying its stype: in the col_to_stype argument passed to Dataset, we specify stype.text_embedded for columns using option (1) and stype.text_tokenized for columns using option (2). Let’s use a real-world dataset to learn how to achieve this.
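For example, such a mapping could look as follows (a minimal sketch with hypothetical column names for illustration):

from torch_frame import stype

# Hypothetical columns for illustration
col_to_stype = {
    'review': stype.text_embedded,   # option (1): pre-encode, keep frozen
    'title': stype.text_tokenized,   # option (2): tokenize, fine-tune
    'price': stype.numerical,
}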

Handling Text Columns in a Real-World Dataset

PyTorch Frame provides a collection of tabular benchmark datasets with text columns, such as MultimodalTextBenchmark.

As we briefly discussed, PyTorch Frame provides two semantic types for text columns:

1. stype.text_embedded will pre-encode texts using user-specified text embedding models at the dataset materialization stage.

2. stype.text_tokenized will tokenize texts using user-specified text tokenizers at the dataset materialization stage. The tokenized texts (sequences of integers) are fed into text models at the training stage, and the parameters of the text models are fine-tuned.

The processes of initializing and materializing datasets are similar to those in Introduction by Example. Below, we highlight the differences for each semantic type.

Pre-encode texts into embeddings

For stype.text_embedded, you first need to specify the text embedding model. Here, we use the SentenceTransformer package.

pip install -U sentence-transformers

Specifying Text Embedders

Next, we create a text encoder class that encodes a list of strings into text embeddings in a mini-batch manner.

from typing import List
import torch
from torch import Tensor
from sentence_transformers import SentenceTransformer

class TextToEmbedding:
    def __init__(self, device: torch.device):
        self.model = SentenceTransformer('all-distilroberta-v1', device=device)

    def __call__(self, sentences: List[str]) -> Tensor:
        # Encode a list of batch_size sentences into a PyTorch Tensor of
        # size [batch_size, emb_dim]
        embeddings = self.model.encode(
            sentences,
            convert_to_numpy=False,
            convert_to_tensor=True,
        )
        return embeddings.cpu()
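As a quick sanity check (illustrative only; the example sentences are made up), the encoder can be called directly on a small list of strings:

embedder = TextToEmbedding(device=torch.device('cpu'))
emb = embedder(['A crisp, dry white wine.', 'Bold tannins and dark fruit.'])
print(emb.shape)  # torch.Size([2, 768]) for all-distilroberta-v1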

Then, we instantiate a TextEmbedderConfig that specifies the text_embedder and the batch_size used to pre-encode the texts with the text_embedder.

from torch_frame.config.text_embedder import TextEmbedderConfig

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

col_to_text_embedder_cfg = TextEmbedderConfig(
    text_embedder=TextToEmbedding(device),
    batch_size=8,
)

Note that Transformer-based text embedding models are often GPU-memory intensive, so it is important to specify a reasonable batch_size (e.g., 8). Also note that the same TextEmbedderConfig is used across all text columns by default. If we want to use a different text_embedder for different text columns (say, "text_col0" and "text_col1"), we can use a dictionary as follows:

# Prepare text_embedder0 and text_embedder1 for text_col0 and text_col1, respectively.
col_to_text_embedder_cfg = {
    "text_col0":
    TextEmbedderConfig(text_embedder=text_embedder0, batch_size=4),
    "text_col1":
    TextEmbedderConfig(text_embedder=text_embedder1, batch_size=8),
}
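Here, text_embedder0 and text_embedder1 stand for any callables with the same interface as TextToEmbedding above. One way to obtain them (a sketch; 'all-MiniLM-L6-v2' is just an illustrative second model choice):

from sentence_transformers import SentenceTransformer

text_embedder0 = TextToEmbedding(device)

class TextToMiniLMEmbedding(TextToEmbedding):
    def __init__(self, device: torch.device):
        # Same interface as TextToEmbedding, different underlying model
        self.model = SentenceTransformer('all-MiniLM-L6-v2', device=device)

text_embedder1 = TextToMiniLMEmbedding(device)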

Embedding Text Columns for a Dataset

Once col_to_text_embedder_cfg is specified, we can pass it to the Dataset object as follows.

import torch_frame
from torch_frame.datasets import MultimodalTextBenchmark

dataset = MultimodalTextBenchmark(
    root='/tmp/multimodal_text_benchmark/wine_reviews',
    name='wine_reviews',
    col_to_text_embedder_cfg=col_to_text_embedder_cfg,
)

dataset.feat_cols  # This dataset contains one text column `description`
>>> ['description', 'country', 'province', 'points', 'price']

dataset.col_to_stype['description']
>>> <stype.text_embedded: 'text_embedded'>

We then call dataset.materialize(path=...), which uses the text embedding models to pre-encode text_embedded columns based on the given col_to_text_embedder_cfg.

# Pre-encode text columns based on col_to_text_embedder_cfg. This may take a while.
dataset.materialize(path='/tmp/multimodal_text_benchmark/wine_reviews/data.pt')

len(dataset)
>>> 105154

# Text embeddings are stored as MultiNestedTensor
dataset.tensor_frame.feat_dict[torch_frame.embedding]
>>> MultiNestedTensor(num_rows=105154, num_cols=1, device='cpu')

It is strongly recommended to specify the path argument in materialize(). It caches the generated TensorFrame, avoiding re-embedding the texts on every materialization run, which can be quite time-consuming. Once cached, the TensorFrame is reused by subsequent materialize() calls.
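For instance, re-running materialization with the same path loads the cached TensorFrame instead of re-encoding the texts:

# Loads the cached TensorFrame; no re-encoding happens here.
dataset.materialize(path='/tmp/multimodal_text_benchmark/wine_reviews/data.pt')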

Note

Internally, text_embedded is grouped together in the parent stype embedding within TensorFrame.

Fusing Text Embeddings into Tabular Learning

PyTorch Frame offers LinearEmbeddingEncoder, designed for encoding stype.embedding within TensorFrame. This module applies a linear function over the pre-computed embeddings.

from torch_frame import stype
from torch_frame.nn.encoder import (
    EmbeddingEncoder,
    LinearEmbeddingEncoder,
    LinearEncoder,
)

stype_encoder_dict = {
    stype.categorical: EmbeddingEncoder(),
    stype.numerical: LinearEncoder(),
    stype.embedding: LinearEmbeddingEncoder()
}

Then, stype_encoder_dict can be directly fed into StypeWiseFeatureEncoder.
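For example (a minimal sketch; out_channels=64 is an arbitrary choice, and the remaining arguments come from the materialized dataset above):

from torch_frame.nn import StypeWiseFeatureEncoder

encoder = StypeWiseFeatureEncoder(
    out_channels=64,  # hypothetical embedding size
    col_stats=dataset.col_stats,
    col_names_dict=dataset.tensor_frame.col_names_dict,
    stype_encoder_dict=stype_encoder_dict,
)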

Fine-tuning Text Models

In contrast to stype.text_embedded, stype.text_tokenized does minimal processing at the dataset materialization stage by only tokenizing raw texts, i.e., transforming strings into sequences of integers. Then, during the training stage, the fully-fledged text models take the tokenized sentences as input and output text embeddings, which allows the text models to be trained in an end-to-end manner.

Here, we use the Transformers package.

pip install transformers

Specifying Text Tokenization

With stype.text_tokenized, text columns are tokenized during the dataset materialization stage. Let’s first create a tokenization class that tokenizes a list of strings into a dictionary of torch.Tensor.

from typing import List
from transformers import AutoTokenizer
from torch_frame.typing import TextTokenizationOutputs

class TextToEmbeddingTokenization:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

    def __call__(self, sentences: List[str]) -> TextTokenizationOutputs:
        # Tokenize batches of sentences
        return self.tokenizer(
            sentences,
            truncation=True,
            padding=True,
            return_tensors='pt',
        )

Here, the output TextTokenizationOutputs is a dictionary whose keys include input_ids and attention_mask, and whose values contain the tensors of token indices and attention masks.
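As a quick sanity check (illustrative; the sentences are made up):

tokenize = TextToEmbeddingTokenization()
out = tokenize(['A crisp, dry white wine.', 'Bold tannins.'])
print(out['input_ids'].shape)       # [2, batch_max_seq_len]
print(out['attention_mask'].shape)  # [2, batch_max_seq_len]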

Then, we instantiate a TextTokenizerConfig for our text embedding model as follows.

from torch_frame.config.text_tokenizer import TextTokenizerConfig

col_to_text_tokenizer_cfg = TextTokenizerConfig(
    text_tokenizer=TextToEmbeddingTokenization(),
    batch_size=10_000,
)

Here, text_tokenizer maps a list of sentences into a dictionary of torch.Tensor, which is input to the text models at training time. Tokenization is processed in mini-batches, where batch_size is the mini-batch size. Because text tokenizers run fast on CPU, we can specify a relatively large batch_size here. Also, note that a dictionary of TextTokenizerConfig can be specified for different text columns with stype.text_tokenized:

# Prepare text_tokenizer0 and text_tokenizer1 for text_col0 and text_col1, respectively.
col_to_text_tokenizer_cfg = {
    "text_col0":
    TextTokenizerConfig(text_tokenizer=text_tokenizer0, batch_size=10000),
    "text_col1":
    TextTokenizerConfig(text_tokenizer=text_tokenizer1, batch_size=20000),
}
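Here, text_tokenizer0 and text_tokenizer1 stand for any callables with the same interface as TextToEmbeddingTokenization above, for example (a sketch; the model names are illustrative choices):

from typing import List
from transformers import AutoTokenizer
from torch_frame.typing import TextTokenizationOutputs

class TextTokenization:
    def __init__(self, model_name: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def __call__(self, sentences: List[str]) -> TextTokenizationOutputs:
        return self.tokenizer(sentences, truncation=True, padding=True,
                              return_tensors='pt')

text_tokenizer0 = TextTokenization('distilbert-base-uncased')
text_tokenizer1 = TextTokenization('roberta-base')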

Tokenizing Text Columns for a Dataset

Once col_to_text_tokenizer_cfg is specified, we can pass it to the Dataset object as follows.

import torch_frame
from torch_frame.datasets import MultimodalTextBenchmark

dataset = MultimodalTextBenchmark(
    root='/tmp/multimodal_text_benchmark/wine_reviews',
    name='wine_reviews',
    text_stype=torch_frame.text_tokenized,
    col_to_text_tokenizer_cfg=col_to_text_tokenizer_cfg,
)

dataset.col_to_stype['description']
>>> <stype.text_tokenized: 'text_tokenized'>

We then call dataset.materialize(), which will use the text tokenizers to pre-tokenize text_tokenized columns based on the given col_to_text_tokenizer_cfg.

# Pre-tokenize text columns based on col_to_text_tokenizer_cfg.
dataset.materialize()

# A dictionary of text tokenization results
dataset.tensor_frame.feat_dict[torch_frame.text_tokenized]
>>> {'input_ids': MultiNestedTensor(num_rows=105154, num_cols=1, device='cpu'), 'attention_mask': MultiNestedTensor(num_rows=105154, num_cols=1, device='cpu')}

Notice that we use a dictionary of MultiNestedTensor to store the tokenized results. We use a dictionary because common text tokenizers usually return multiple text-model inputs, such as input_ids and attention_mask, as shown above.

Fine-tuning Text Models with Tabular Learning

PyTorch Frame offers LinearModelEncoder, designed to flexibly apply any learnable module in a per-column manner. We first specify a ModelConfig object that declares the module to apply to each column.

Note

ModelConfig has two arguments to specify: first, model is a learnable module that takes per-column tensors in TensorFrame as input and outputs per-column embeddings. Formally, model takes a TensorData object of shape [batch_size, 1, *] as input and outputs embeddings of shape [batch_size, 1, out_channels]. Second, out_channels specifies the output embedding dimensionality of model.

We can use the above LinearModelEncoder functionality for embedding stype.text_tokenized within TensorFrame.

To use this functionality, let us first prepare the model for ModelConfig. Here, we use the PEFT package and the LoRA strategy to fine-tune the underlying text model.

pip install peft

We then design model as a DistilBERT model with LoRA fine-tuning. Note that model needs to take the per-column feat as input and output embeddings of shape [batch_size, 1, out_channels]. As mentioned, the per-column feat is a dictionary of MultiNestedTensor in the case of stype.text_tokenized. During forward(), we first transform each MultiNestedTensor into a padded torch.Tensor using to_dense(), with the padding value specified by fill_value.

import torch
from torch import Tensor
from transformers import AutoModel
from torch_frame.data import MultiNestedTensor
from peft import LoraConfig, TaskType, get_peft_model

class TextToEmbeddingFinetune(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = AutoModel.from_pretrained('distilbert-base-uncased')
        # Set LoRA config
        peft_config = LoraConfig(
            task_type=TaskType.FEATURE_EXTRACTION,
            r=32,
            lora_alpha=32,
            inference_mode=False,
            lora_dropout=0.1,
            bias="none",
            target_modules=["ffn.lin1"],
        )
        # Update the model with LoRA config
        self.model = get_peft_model(self.model, peft_config)

    def forward(self, feat: dict[str, MultiNestedTensor]) -> Tensor:
        # Pad [batch_size, 1, *] into [batch_size, 1, batch_max_seq_len], then,
        # squeeze to [batch_size, batch_max_seq_len].
        input_ids = feat["input_ids"].to_dense(fill_value=0).squeeze(dim=1)
        # Set attention_mask of padding idx to be False
        mask = feat["attention_mask"].to_dense(fill_value=0).squeeze(dim=1)

        # Get text embeddings for each text tokenized column
        # out.last_hidden_state has the shape:
        # [batch_size, batch_max_seq_len, out_channels]
        out = self.model(input_ids=input_ids, attention_mask=mask)

        # Use the CLS embedding to represent the sentence embedding
        # Return value has the shape [batch_size, 1, out_channels]
        return out.last_hidden_state[:, 0, :].unsqueeze(1)
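Optionally, we can verify how small the set of trainable parameters is after applying LoRA (print_trainable_parameters() is provided by the PEFT package):

model = TextToEmbeddingFinetune()
# Prints a summary of the form:
# trainable params: ... || all params: ... || trainable%: ...
model.model.print_trainable_parameters()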

Now we have prepared model. We can instantiate the ModelConfig object by additionally supplying the out_channels argument. In the case of DistilBERT, out_channels is 768.

from torch_frame.config import ModelConfig
model_cfg = ModelConfig(model=TextToEmbeddingFinetune(), out_channels=768)

We then specify col_to_model_cfg, mapping each column name to its desired model_cfg.

col_to_model_cfg = {"description": model_cfg}

We can now pass col_to_model_cfg to LinearModelEncoder so that it applies the specified model to the desired column. In this case, we apply the model TextToEmbeddingFinetune to the stype.text_tokenized column called "description" within TensorFrame.

from torch_frame import stype
from torch_frame.nn import (
    EmbeddingEncoder,
    LinearEncoder,
    LinearModelEncoder,
)

stype_encoder_dict = {
    stype.categorical: EmbeddingEncoder(),
    stype.numerical: LinearEncoder(),
    stype.text_tokenized: LinearModelEncoder(col_to_model_cfg=col_to_model_cfg),
}

The resulting stype_encoder_dict can be directly fed into StypeWiseFeatureEncoder.
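The wiring is the same as in the pre-encoding case (again a minimal sketch; out_channels=64 is an arbitrary choice):

from torch_frame.nn import StypeWiseFeatureEncoder

encoder = StypeWiseFeatureEncoder(
    out_channels=64,
    col_stats=dataset.col_stats,
    col_names_dict=dataset.tensor_frame.col_names_dict,
    stype_encoder_dict=stype_encoder_dict,
)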

Please refer to pytorch-frame/examples/transformers_text.py for more information on text embedding and fine-tuning with the Transformers package.

Also, please refer to pytorch-frame/examples/llm_embedding.py for more information on text embedding with large language models such as OpenAI embeddings and Cohere Embed.