Source code for torch_frame.datasets.titanic

import pandas as pd

import torch_frame


[docs]class Titanic(torch_frame.data.Dataset): r"""The Titanic dataset from the `Titanic Kaggle competition <https://www.kaggle.com/competitions/titanic>`_. The Titanic dataset is known as the MNIST equivalent for tabular learning. The goal is to predict which passenger survived using passenger data (*i.e.* gender, age, etc). **STATS:** .. list-table:: :widths: 10 10 10 10 20 10 :header-rows: 1 * - #rows - #cols (numerical) - #cols (categorical) - #classes - Task - Missing value ratio * - 891 - 4 - 3 - 2 - binary_classification - 8.8% """ url = 'https://github.com/datasciencedojo/datasets/raw/master/titanic.csv' def __init__(self, root: str) -> None: path = self.download_url(self.url, root) df = pd.read_csv(path, index_col=['PassengerId']) col_to_stype = { # TODO Use 'Name', 'Ticket' and 'Cabin'. 'Survived': torch_frame.categorical, 'Pclass': torch_frame.categorical, 'Sex': torch_frame.categorical, 'Age': torch_frame.numerical, 'SibSp': torch_frame.numerical, 'Parch': torch_frame.numerical, 'Fare': torch_frame.numerical, 'Embarked': torch_frame.categorical, } super().__init__(df, col_to_stype, target_col='Survived')