from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from PIL import Image
from torch import Tensor
class ImageEmbedder(ABC):
    r"""Parent class for the :obj:`image_embedder` of
    :class:`ImageEmbedderConfig`. This class first retrieves images based
    on the given paths stored in the data frame and then embeds the
    retrieved images into a tensor. Users are responsible for implementing
    :meth:`forward_embed`, which takes a list of images and returns an
    embedding tensor. Users can also override :meth:`forward_retrieve`,
    which takes the paths to images and returns a list of
    :obj:`PIL.Image.Image`.
    """
    def forward_retrieve(self, path_to_images: list[str]) -> list[Image.Image]:
        r"""Retrieval function that reads a list of images from
        a list of file paths with the :obj:`RGB` mode.

        Args:
            path_to_images (list[str]): Paths of the image files to read.

        Returns:
            list[PIL.Image.Image]: One image per path, in the same order,
            each converted to :obj:`RGB` mode.
        """
        images: list[Image.Image] = []
        for path_to_image in path_to_images:
            # The context manager guarantees the file handle is closed even
            # if decoding or conversion raises. ``convert`` loads the pixel
            # data and returns a new image, so the appended image remains
            # usable after the underlying file has been closed. This also
            # avoids the extra ``copy()`` the conversion would make anyway.
            with Image.open(path_to_image) as image:
                images.append(image.convert('RGB'))
        return images

    @abstractmethod
    def forward_embed(self, images: list[Image.Image]) -> Tensor:
        r"""Embedding function that takes a list of images and returns
        an embedding tensor.

        Subclasses must implement this method.

        Args:
            images (list[PIL.Image.Image]): Images produced by
                :meth:`forward_retrieve`.

        Returns:
            Tensor: Embeddings for :obj:`images` (presumably one row per
            image; the exact layout is defined by the subclass).
        """
        raise NotImplementedError

    def __call__(self, path_to_images: list[str]) -> Tensor:
        r"""Retrieve the images at :obj:`path_to_images` via
        :meth:`forward_retrieve` and embed them via :meth:`forward_embed`.
        """
        images = self.forward_retrieve(path_to_images)
        return self.forward_embed(images)
@dataclass
class ImageEmbedderConfig:
    r"""Image embedder model that maps a list of images into PyTorch
    Tensor embeddings.

    Args:
        image_embedder (callable): A callable image embedder that takes a
            list of paths to images as input and outputs the PyTorch Tensor
            embeddings for that list of images. Usually it contains a
            retriever to load image files and then an embedder converting
            images to embeddings.
        batch_size (int, optional): Batch size to use when encoding the
            images. If set to :obj:`None`, the image embeddings will
            be obtained in a full-batch manner. (default: :obj:`None`)
    """
    # Callable mapping a list of image paths to an embedding tensor.
    image_embedder: Callable[[list[str]], Tensor]

    # Number of images encoded per call. Setting a finite value is
    # recommended for heavy image models (e.g., ViT) running on GPU;
    # :obj:`None` encodes all images in a single full batch.
    batch_size: int | None = None