Source code for torch_frame.datasets.mushroom

import os.path as osp
import zipfile

import pandas as pd

import torch_frame


class Mushroom(torch_frame.data.Dataset):
    r"""The `Mushroom classification Kaggle competition
    <https://www.kaggle.com/datasets/uciml/mushroom-classification>`_
    dataset. The task is to predict whether a mushroom is edible or
    poisonous.

    On construction the archive at :attr:`url` is fetched via
    :meth:`download_url`, unpacked into the directory that holds the
    downloaded file, and ``agaricus-lepiota.data`` is loaded with pandas
    using an explicit list of column names. Every column is registered as
    categorical, and ``class`` is used as the target column.

    **STATS:**

    .. list-table::
        :widths: 10 10 10 10 20 10
        :header-rows: 1

        * - #rows
          - #cols (numerical)
          - #cols (categorical)
          - #classes
          - Task
          - Missing value ratio
        * - 8,124
          - 0
          - 22
          - 2
          - binary_classification
          - 0.0%

    Args:
        root (str): Forwarded to :meth:`download_url` together with
            :attr:`url`; the archive is extracted into the directory of
            the path that call returns.
    """

    url = 'http://archive.ics.uci.edu/static/public/73/mushroom.zip'

    def __init__(self, root: str):
        archive_path = self.download_url(self.url, root)
        extract_dir = osp.dirname(archive_path)

        # Unpack next to the downloaded archive.
        with zipfile.ZipFile(archive_path, 'r') as archive:
            archive.extractall(extract_dir)

        # Column labels handed to pandas; the target column comes first,
        # followed by the feature columns.
        column_names = [
            'class',
            'cap-shape',
            'cap-surface',
            'cap-color',
            'bruises',
            'odor',
            'gill-attachment',
            'gill-spacing',
            'gill-size',
            'gill-color',
            'stalk-shape',
            'stalk-root',
            'stalk-surface-above-ring',
            'stalk-surface-below-ring',
            'stalk-color-above-ring',
            'stalk-color-below-ring',
            'veil-type',
            'veil-color',
            'ring-number',
            'ring-type',
            'spore-print-color',
            'population',
            'habitat',
        ]
        frame = pd.read_csv(
            osp.join(extract_dir, 'agaricus-lepiota.data'),
            names=column_names,
        )

        # Every column, target included, is categorical, so the mapping is
        # derived from the single list of names above.
        stype_by_column = dict.fromkeys(column_names, torch_frame.categorical)

        super().__init__(frame, stype_by_column, target_col='class')