from __future__ import annotations
import os
import os.path as osp
import random
import string
from datetime import datetime, timedelta
import numpy as np
import pandas as pd
from PIL import Image
import torch_frame
from torch_frame import stype
from torch_frame.config.image_embedder import ImageEmbedderConfig
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_frame.config.text_tokenizer import TextTokenizerConfig
from torch_frame.typing import TaskType
from torch_frame.utils.split import SPLIT_TO_NUM
# ``strftime`` patterns for the fake timestamp columns: column
# ``timestamp_{i}`` is rendered with ``TIME_FORMATS[i]``, and the same
# mapping is handed to the dataset as ``col_to_time_format``.
TIME_FORMATS = [
    '%Y-%m-%d %H:%M:%S',
    '%Y-%m-%d',
    '%Y/%m/%d',
]
def _random_timestamp(start: datetime, end: datetime, format: str) -> str:
    r"""Pick a random moment between ``start`` and ``end`` and format it.

    The moment is ``start`` shifted forward by a whole number of seconds,
    chosen with :func:`random.randint` from ``0`` up to the (truncated)
    length of the interval, so the result never lies past ``end``.

    Args:
        start (datetime): Lower bound of the interval.
        end (datetime): Upper bound of the interval.
        format (str): ``strftime`` pattern used to render the result.

    Returns:
        str: The chosen moment rendered with ``format``.
    """
    span_seconds = int((end - start).total_seconds())
    offset = timedelta(seconds=random.randint(0, span_seconds))
    chosen = start + offset
    return chosen.strftime(format)
def _generate_random_string(min_length: int, max_length: int) -> str:
    r"""Build a random string of ASCII letters.

    The length is drawn with :func:`random.randint`, so it lies in
    ``[min_length, max_length]`` with both ends inclusive; each character is
    then drawn independently from :data:`string.ascii_letters`.
    """
    n_chars = random.randint(min_length, max_length)
    alphabet = string.ascii_letters
    chars = []
    for _ in range(n_chars):
        chars.append(random.choice(alphabet))
    return ''.join(chars)
def _random_text_column(num_rows: int, with_nan: bool) -> list:
    r"""Return ``num_rows`` random text values for a fake text column.

    Each value is two words of 5-15 random ASCII letters joined by a single
    space. If ``with_nan`` is :obj:`True`, every even-indexed entry is
    replaced by :obj:`numpy.nan`.
    """
    values: list = [
        ' '.join([
            _generate_random_string(5, 15),
            _generate_random_string(5, 15),
        ]) for _ in range(num_rows)
    ]
    if with_nan:
        values[0::2] = len(values[0::2]) * [np.nan]
    return values


class FakeDataset(torch_frame.data.Dataset):
    r"""A fake dataset for testing purposes.

    Args:
        num_rows (int): Number of rows.
        with_nan (bool): Whether to include missing values (NaN / None) in
            the feature columns. The target column never contains missing
            values. (default: :obj:`False`)
        stypes (List[stype], optional): List of stype columns to include
            in the dataset. Particularly useful when you want to
            create a dataset with only numerical or categorical
            feature columns. :obj:`None` or an empty list falls back to the
            default. (default: [stype.categorical, stype.numerical])
        create_split (bool): Whether to create a train, val and test
            split for the fake dataset. (default: :obj:`False`)
        task_type (TaskType): Task type.
            (default: :obj:`TaskType.REGRESSION`)
        col_to_text_embedder_cfg (Dict[str, TextEmbedderConfig] or
            TextEmbedderConfig, optional): Forwarded unchanged to
            :class:`torch_frame.data.Dataset`.
        col_to_text_tokenizer_cfg (Dict[str, TextTokenizerConfig] or
            TextTokenizerConfig, optional): Forwarded unchanged to
            :class:`torch_frame.data.Dataset`.
        col_to_image_embedder_cfg (Dict[str, ImageEmbedderConfig] or
            ImageEmbedderConfig, optional): Forwarded unchanged to
            :class:`torch_frame.data.Dataset`.
        tmp_path (str, optional): Directory in which the generated images
            are saved. Required when :obj:`stype.image_embedded` is in
            :obj:`stypes`.

    Raises:
        ValueError: If :obj:`task_type` is unsupported, if :obj:`num_rows`
            is too small for the task type or for :obj:`create_split`, or
            if image columns are requested without :obj:`tmp_path`.
    """
    def __init__(
        self,
        num_rows: int,
        with_nan: bool = False,
        stypes: list[stype] | None = None,
        create_split: bool = False,
        task_type: TaskType = TaskType.REGRESSION,
        col_to_text_embedder_cfg: dict[str, TextEmbedderConfig]
        | TextEmbedderConfig | None = None,
        col_to_text_tokenizer_cfg: dict[str, TextTokenizerConfig]
        | TextTokenizerConfig | None = None,
        col_to_image_embedder_cfg: dict[str, ImageEmbedderConfig]
        | ImageEmbedderConfig | None = None,
        tmp_path: str | None = None,
    ) -> None:
        # An empty list is falsy, so it also falls back to the defaults.
        stypes = stypes or [stype.categorical, stype.numerical]

        df_dict: dict[str, list | np.ndarray]
        col_to_stype: dict[str, stype]

        # Target column. Row-count checks run before any random draw.
        if task_type == TaskType.REGRESSION:
            df_dict = {'target': np.random.randn(num_rows)}
            col_to_stype = {'target': stype.numerical}
        elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
            if num_rows < 3:
                raise ValueError("Number of rows needs to be at "
                                 "least 3 for multiclass classification")
            labels = np.random.randint(0, 3, size=(num_rows, ))
            # Make sure every label exists.
            labels[0] = 0
            labels[1] = 1
            labels[2] = 2
            df_dict = {'target': labels}
            col_to_stype = {'target': stype.categorical}
        elif task_type == TaskType.BINARY_CLASSIFICATION:
            if num_rows < 2:
                raise ValueError("Number of rows needs to be at "
                                 "least 2 for binary classification")
            labels = np.random.randint(0, 2, size=(num_rows, ))
            # Make sure both labels exist.
            labels[0] = 0
            labels[1] = 1
            df_dict = {'target': labels}
            col_to_stype = {'target': stype.categorical}
        else:
            raise ValueError(
                "FakeDataset only support binary classification, "
                "multiclass classification or regression type, but"
                f" got {task_type}")

        # Validate the remaining arguments before generating features, so a
        # guaranteed failure does not first write image files to disk.
        if stype.image_embedded in stypes and tmp_path is None:
            raise ValueError("`tmp_path` must be provided to create "
                             "image_embedded columns.")
        if create_split and num_rows < 3:
            raise ValueError("Dataframe needs at least 3 rows to include"
                             " each of train, val and test split.")

        if stype.numerical in stypes:
            for col_name in ['num_1', 'num_2', 'num_3']:
                arr = np.random.randn(num_rows)
                if with_nan:
                    arr[0::2] = np.nan
                df_dict[col_name] = arr
                col_to_stype[col_name] = stype.numerical

        if stype.categorical in stypes:
            for col_name in ['cat_1', 'cat_2']:
                arr = np.random.randint(0, 3, size=(num_rows, ))
                if with_nan:
                    # Integer arrays cannot hold NaN; cast to float first.
                    arr = arr.astype(np.float32)
                    arr[1::2] = np.nan
                df_dict[col_name] = arr
                col_to_stype[col_name] = stype.categorical

        if stype.multicategorical in stypes:
            vocab = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
            for col_name in [
                    'multicat_1', 'multicat_2', 'multicat_3', 'multicat_4'
            ]:
                # The first two columns hold comma-joined strings (their
                # separator is registered in `col_to_sep` below); the
                # others hold Python lists.
                as_string = col_name in ('multicat_1', 'multicat_2')
                values: list = []
                for _ in range(num_rows):
                    sampled = random.sample(vocab, 3)
                    values.append(','.join(sampled) if as_string else sampled)
                if with_nan and values:
                    values[0] = None
                df_dict[col_name] = values
                col_to_stype[col_name] = stype.multicategorical

        if stype.sequence_numerical in stypes:
            for col_name in ['seq_num_1', 'seq_num_2']:
                sequences: list = []
                for _ in range(num_rows):
                    sequence_length = random.randint(1, 5)
                    sequence = [
                        random.random() for _ in range(sequence_length)
                    ]
                    # Every sequence gets one NaN element regardless of
                    # `with_nan`; `with_nan` additionally nulls a whole row.
                    nan_idx = random.randint(0, sequence_length - 1)
                    sequence[nan_idx] = np.nan
                    sequences.append(sequence)
                if with_nan and sequences:
                    sequences[0] = None
                df_dict[col_name] = sequences
                col_to_stype[col_name] = stype.sequence_numerical

        if stype.text_embedded in stypes:
            for col_name in ['text_embedded_1', 'text_embedded_2']:
                df_dict[col_name] = _random_text_column(num_rows, with_nan)
                col_to_stype[col_name] = stype.text_embedded

        if stype.text_tokenized in stypes:
            for col_name in ['text_tokenized_1', 'text_tokenized_2']:
                df_dict[col_name] = _random_text_column(num_rows, with_nan)
                col_to_stype[col_name] = stype.text_tokenized

        if stype.embedding in stypes:
            for col_name in ['emb_1', 'emb_2']:
                emb_dim = random.randint(1, 5)
                emb = [random.random() for _ in range(emb_dim)]
                # Every row holds the same values, but each row gets its own
                # list so mutating one row cannot silently change the rest.
                df_dict[col_name] = [list(emb) for _ in range(num_rows)]
                col_to_stype[col_name] = stype.embedding

        if stype.timestamp in stypes:
            start_date = datetime(2000, 1, 1)
            end_date = datetime(2023, 1, 1)
            for i, time_format in enumerate(TIME_FORMATS):
                col_name = f'timestamp_{i}'
                stamps: list = [
                    _random_timestamp(start_date, end_date, time_format)
                    for _ in range(num_rows)
                ]
                if with_nan:
                    stamps[0::2] = len(stamps[0::2]) * [np.nan]
                df_dict[col_name] = stamps
                col_to_stype[col_name] = stype.timestamp

        if stype.image_embedded in stypes:
            for col_name in ['image_embedded_1', 'image_embedded_2']:
                img_dir = osp.join(tmp_path, col_name)
                os.makedirs(img_dir, exist_ok=True)
                img_paths: list = []
                for i in range(num_rows):
                    img_path = osp.join(img_dir, f'{i}.png')
                    img = Image.new('RGB', (24, 24))
                    img.save(img_path)
                    img.close()
                    img_paths.append(img_path)
                df_dict[col_name] = img_paths
                col_to_stype[col_name] = stype.image_embedded

        df = pd.DataFrame(df_dict)
        if create_split:
            # TODO: Instead of having a split column name with train, val and
            # test, we will implement `random_split` and `split_by_col`
            # function in the Dataset class. We will modify the following
            # lines when the functions are introduced.
            # Rows 1 and 2 are forced to val/test; `num_rows >= 3` was
            # validated above.
            split = [SPLIT_TO_NUM['train']] * num_rows
            split[1] = SPLIT_TO_NUM['val']
            split[2] = SPLIT_TO_NUM['test']
            df['split'] = split

        super().__init__(
            df,
            col_to_stype,
            target_col='target',
            split_col='split' if create_split else None,
            col_to_sep={
                'multicat_1': ',',
                'multicat_2': ',',
            },
            col_to_text_embedder_cfg=col_to_text_embedder_cfg,
            col_to_text_tokenizer_cfg=col_to_text_tokenizer_cfg,
            col_to_time_format={
                f'timestamp_{i}': time_format
                for i, time_format in enumerate(TIME_FORMATS)
            },
            col_to_image_embedder_cfg=col_to_image_embedder_cfg,
        )