# Source code for torch_frame.datasets.mercari

from __future__ import annotations

import os.path as osp

import pandas as pd

import torch_frame
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_frame.utils.split import SPLIT_TO_NUM

# Name of the column added to the loaded table to hold each row's split
# index; the same name is handed to the base dataset as ``split_col``.
SPLIT_COL = 'split_col'


class Mercari(torch_frame.data.Dataset):
    r"""The `Mercari Price Suggestion Challenge
    <https://www.kaggle.com/c/mercari-price-suggestion-challenge/>`_
    dataset from Kaggle.

    The train and test CSV files are downloaded, each row is tagged with
    its split index in the :obj:`SPLIT_COL` column, and both tables are
    concatenated (train rows first) into a single table whose target
    column is :obj:`'price'`.

    Args:
        root (str): Directory passed to :meth:`download_url` as the
            download destination for the raw CSV files.
        num_rows (int, optional): Number of rows to subsample. When
            given, only the first :obj:`num_rows` rows of the
            concatenated table are kept. (default: :obj:`None`)
        col_to_text_embedder_cfg (TextEmbedderConfig or dict, optional):
            Text embedder configuration, either a single one or one per
            column name; forwarded unchanged to the base
            :class:`torch_frame.data.Dataset` constructor.
            (default: :obj:`None`)

    **STATS:**

    .. list-table::
        :widths: 10 10 10 10 20 10
        :header-rows: 1

        * - #rows
          - #cols (numerical)
          - #cols (categorical)
          - #cols (text_embedded)
          - Task
          - Missing value ratio
        * - 1,482,535
          - 1
          - 4
          - 2
          - regression
          - 0.0%
    """

    base_url = 'https://data.pyg.org/datasets/tables/mercari_price_suggestion/'
    files = ['train', 'test_stg2']

    def __init__(
        self,
        root: str,
        num_rows: int | None = None,
        col_to_text_embedder_cfg: dict[str, TextEmbedderConfig]
        | TextEmbedderConfig
        | None = None,
    ) -> None:
        col_to_stype = {
            'name': torch_frame.text_embedded,
            'item_condition_id': torch_frame.categorical,
            'category_name': torch_frame.multicategorical,
            'brand_name': torch_frame.categorical,
            'price': torch_frame.numerical,
            'shipping': torch_frame.categorical,
            'item_description': torch_frame.text_embedded,
        }

        # (file stem, split name) pairs, listed in concatenation order.
        frames = []
        for stem, split in (('train', 'train'), ('test_stg2', 'test')):
            # `base_url` already ends with '/', so plain concatenation
            # yields the URL; `os.path.join` is for filesystem paths.
            url = f'{self.base_url}{stem}.csv'
            # Parse the local copy that `download_url` hands back instead
            # of the remote URL, so pandas does not fetch the file again.
            local_path = self.download_url(url, root)
            frame = pd.read_csv(local_path)
            frame[SPLIT_COL] = SPLIT_TO_NUM[split]
            frames.append(frame)

        df = pd.concat(frames, axis=0, ignore_index=True)
        if num_rows is not None:
            df = df.head(num_rows)
        # Non-mutating drop: `df` may be the result of `head()`, and
        # dropping in place on a derived frame is best avoided.
        df = df.drop(columns=['train_id'])

        super().__init__(
            df,
            col_to_stype,
            target_col='price',
            col_to_sep="/",
            col_to_text_embedder_cfg=col_to_text_embedder_cfg,
            split_col=SPLIT_COL,
        )