# Source code for torch_frame.config.text_embedder

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass

from torch import Tensor


@dataclass
class TextEmbedderConfig:
    r"""Configuration for a text embedder that maps a list of
    strings/sentences into PyTorch Tensor embeddings.

    Args:
        text_embedder (callable): A callable text embedder that takes a list
            of strings as input and outputs the PyTorch Tensor embeddings for
            that list of strings.
        batch_size (int, optional): Batch size to use when encoding the
            sentences. If set to :obj:`None`, the text embeddings will be
            obtained in a full-batch manner. (default: :obj:`None`)
    """
    # Callable that turns a list of strings into their Tensor embeddings.
    text_embedder: Callable[[list[str]], Tensor]
    # Batch size to use when encoding the sentences. It is recommended to set
    # it to a reasonable value when one uses a heavy text embedding model
    # (e.g., Transformer) on GPU. If set to :obj:`None`, the text embeddings
    # will be obtained in a full-batch manner.
    batch_size: int | None = None