Source code for torch_frame.datasets.movielens_1m

from __future__ import annotations

import os.path as osp
import zipfile

import pandas as pd

import torch_frame
from torch_frame.config.text_embedder import TextEmbedderConfig


class Movielens1M(torch_frame.data.Dataset):
    r"""The MovieLens 1M rating dataset, assembled by GroupLens Research from
    the MovieLens web site, consisting of movies (3,883 nodes) and users
    (6,040 nodes) with approximately 1 million ratings between them.

    Movie IDs range from 1 to 3,952, so the ID space is larger than the
    number of movies listed in :obj:`movies.dat`.

    Each row of the resulting frame is one rating, inner-joined with the
    attributes of the user who gave it and the movie it rates. As a result,
    only movies with at least one rating appear. Rows are ordered by
    :obj:`timestamp`. Ratings that share a timestamp keep their merge
    order. The target column is the numerical :obj:`rating`.

    Args:
        root (str): Directory passed to the download step. The archive is
            extracted next to the downloaded file.
        col_to_text_embedder_cfg (dict[str, TextEmbedderConfig] or
            TextEmbedderConfig, optional): Text embedder configuration for
            the text-embedded :obj:`title` column. It is forwarded unchanged
            to :class:`torch_frame.data.Dataset`. (default: :obj:`None`)

    **STATS:**

    .. list-table::
        :widths: 10 10 10 10 20
        :header-rows: 1

        * - #Users
          - #Items
          - #User Field
          - #Item Field
          - #Samples
        * - 6040
          - 3883
          - 5
          - 3
          - 1000209
    """

    url = 'https://files.grouplens.org/datasets/movielens/ml-1m.zip'

    def __init__(
        self,
        root: str,
        col_to_text_embedder_cfg: dict[str, TextEmbedderConfig]
        | TextEmbedderConfig | None = None,
    ):
        path = self.download_url(self.url, root)
        folder_path = osp.dirname(path)
        data_path = osp.join(folder_path, 'ml-1m')

        # Re-extract only if some raw file is missing, so constructing the
        # dataset repeatedly does not unpack the archive every time.
        raw_files = ('users.dat', 'movies.dat', 'ratings.dat')
        if not all(
                osp.isfile(osp.join(data_path, fname))
                for fname in raw_files):
            with zipfile.ZipFile(path, 'r') as zip_ref:
                zip_ref.extractall(folder_path)

        def read_dat(fname: str, names: list[str], **kwargs) -> pd.DataFrame:
            # The .dat files have no header row and use the multi-character
            # separator '::', which pandas only supports with the python
            # parser engine.
            return pd.read_csv(
                osp.join(data_path, fname),
                header=None,
                names=names,
                sep='::',
                engine='python',
                **kwargs,
            )

        users = read_dat(
            'users.dat',
            ['user_id', 'gender', 'age', 'occupation', 'zip'],
        )
        # movies.dat is decoded as Latin-1 (ISO-8859-1) rather than the
        # default UTF-8.
        movies = read_dat(
            'movies.dat',
            ['movie_id', 'title', 'genres'],
            encoding='ISO-8859-1',
        )
        ratings = read_dat(
            'ratings.dat',
            ['user_id', 'movie_id', 'rating', 'timestamp'],
        )

        # Inner joins on the shared key columns: first 'user_id', then
        # 'movie_id'. The default sort algorithm (quicksort) leaves the order
        # of equal timestamps unspecified. A stable sort keeps the merge
        # order for ties, so the row order is well-defined.
        df = (pd.merge(pd.merge(ratings, users), movies)
              .sort_values(by='timestamp', kind='stable')
              .reset_index(drop=True))

        col_to_stype = {
            'user_id': torch_frame.categorical,
            'gender': torch_frame.categorical,
            'age': torch_frame.categorical,
            'occupation': torch_frame.categorical,
            'zip': torch_frame.categorical,
            'movie_id': torch_frame.categorical,
            'title': torch_frame.text_embedded,
            # Genres are '|'-separated lists (see col_to_sep below).
            'genres': torch_frame.multicategorical,
            'rating': torch_frame.numerical,
            'timestamp': torch_frame.timestamp,
        }

        super().__init__(df, col_to_stype, target_col='rating',
                         col_to_sep='|',
                         col_to_text_embedder_cfg=col_to_text_embedder_cfg)