# Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Download and extract the MovieLens dataset from GroupLens website. | |
Download the dataset, and perform basic preprocessing. | |
""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import os | |
import sys | |
import tempfile | |
import zipfile | |
# pylint: disable=g-bad-import-order | |
# Import libraries | |
import numpy as np | |
import pandas as pd | |
import six | |
from six.moves import urllib # pylint: disable=redefined-builtin | |
from absl import app | |
from absl import flags | |
from absl import logging | |
import tensorflow as tf, tf_keras | |
# pylint: enable=g-bad-import-order | |
from official.utils.flags import core as flags_core | |
ML_1M = "ml-1m"
ML_20M = "ml-20m"
DATASETS = [ML_1M, ML_20M]
# Filenames of the regularized csv outputs written into each dataset subdir.
RATINGS_FILE = "ratings.csv"
MOVIES_FILE = "movies.csv"
# URL to download dataset
_DATA_URL = "https://files.grouplens.org/datasets/movielens/"
# Canonical column names used after regularization.
GENRE_COLUMN = "genres"
ITEM_COLUMN = "item_id"  # movies
RATING_COLUMN = "rating"
TIMESTAMP_COLUMN = "timestamp"
TITLE_COLUMN = "titles"
USER_COLUMN = "user_id"
# Genre vocabulary used to binarize the pipe-separated genre strings.
GENRES = [
    'Action', 'Adventure', 'Animation', "Children", 'Comedy', 'Crime',
    'Documentary', 'Drama', 'Fantasy', 'Film-Noir', 'Horror', "IMAX", 'Musical',
    'Mystery', 'Romance', 'Sci-Fi', 'Thriller', 'War', 'Western'
]
N_GENRE = len(GENRES)
# Column orders for the regularized ratings.csv and movies.csv files.
RATING_COLUMNS = [USER_COLUMN, ITEM_COLUMN, RATING_COLUMN, TIMESTAMP_COLUMN]
MOVIE_COLUMNS = [ITEM_COLUMN, TITLE_COLUMN, GENRE_COLUMN]
# Note: Users are indexed [1, k], not [0, k-1]
NUM_USER_IDS = {
    ML_1M: 6040,
    ML_20M: 138493,
}
# Note: Movies are indexed [1, k], not [0, k-1]
# Both the 1m and 20m datasets use the same movie set.
NUM_ITEM_IDS = 3952
MAX_RATING = 5
NUM_RATINGS = {
    ML_1M: 1000209,
    ML_20M: 20000263
}
# NOTE(review): presumably (num_users, num_items actually rated) per dataset;
# 3706 < NUM_ITEM_IDS, so this looks like a post-filtering count — confirm.
DATASET_TO_NUM_USERS_AND_ITEMS = {ML_1M: (6040, 3706), ML_20M: (138493, 26744)}
def _download_and_clean(dataset, data_dir):
  """Download MovieLens dataset in a standard format.

  This function downloads the specified MovieLens format and coerces it into a
  standard format. The only difference between the ml-1m and ml-20m datasets
  after this point (other than size, of course) is that the 1m dataset uses
  whole number ratings while the 20m dataset allows half integer ratings.

  Args:
    dataset: The dataset name; must be one of DATASETS.
    data_dir: Root directory; outputs are placed in `data_dir/dataset`.

  Raises:
    ValueError: If `dataset` is not a recognized MovieLens dataset name.
  """
  if dataset not in DATASETS:
    raise ValueError("dataset {} is not in {{{}}}".format(
        dataset, ",".join(DATASETS)))
  data_subdir = os.path.join(data_dir, dataset)
  expected_files = ["{}.zip".format(dataset), RATINGS_FILE, MOVIES_FILE]
  tf.io.gfile.makedirs(data_subdir)
  # If every expected file is already present, skip the download entirely.
  if set(expected_files).intersection(
      tf.io.gfile.listdir(data_subdir)) == set(expected_files):
    logging.info("Dataset {} has already been downloaded".format(dataset))
    return
  url = "{}{}.zip".format(_DATA_URL, dataset)
  # Work in a local temp dir; results are copied to data_subdir at the end so
  # a partial download never leaves a half-written destination.
  temp_dir = tempfile.mkdtemp()
  try:
    zip_path = os.path.join(temp_dir, "{}.zip".format(dataset))
    zip_path, _ = urllib.request.urlretrieve(url, zip_path)
    statinfo = os.stat(zip_path)
    # A new line to clear the carriage return from download progress
    # logging.info is not applicable here
    print()
    logging.info(
        "Successfully downloaded {} {} bytes".format(
            zip_path, statinfo.st_size))
    # Use a context manager so the zip file handle is always closed; the
    # previous code leaked the open ZipFile object.
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
      zip_ref.extractall(temp_dir)
    if dataset == ML_1M:
      _regularize_1m_dataset(temp_dir)
    else:
      _regularize_20m_dataset(temp_dir)
    for fname in tf.io.gfile.listdir(temp_dir):
      if not tf.io.gfile.exists(os.path.join(data_subdir, fname)):
        tf.io.gfile.copy(os.path.join(temp_dir, fname),
                         os.path.join(data_subdir, fname))
      else:
        logging.info("Skipping copy of {}, as it already exists in the "
                     "destination folder.".format(fname))
  finally:
    tf.io.gfile.rmtree(temp_dir)
def _transform_csv(input_path, output_path, names, skip_first, separator=","):
  """Transform csv to a regularized format.

  Args:
    input_path: The path of the raw csv.
    output_path: The path of the cleaned csv.
    names: The csv column names.
    skip_first: Boolean of whether to skip the first line of the raw csv.
    separator: Character used to separate fields in the raw csv.
  """
  if six.PY2:
    names = [six.ensure_text(n, "utf-8") for n in names]
  with tf.io.gfile.GFile(output_path, "wb") as out_file, \
      tf.io.gfile.GFile(input_path, "rb") as in_file:
    # Emit the canonical header row before any data rows.
    out_file.write(",".join(names).encode("utf-8"))
    out_file.write(b"\n")
    for line_num, raw_line in enumerate(in_file):
      if skip_first and line_num == 0:
        continue  # ignore existing labels in the csv
      text = six.ensure_text(raw_line, "utf-8", errors="ignore")
      fields = text.split(separator)
      if separator != ",":
        # Quote fields containing commas so the comma-delimited output parses.
        fields = ['"{}"'.format(f) if "," in f else f for f in fields]
      out_file.write(",".join(fields).encode("utf-8"))
def _regularize_1m_dataset(temp_dir):
  """Convert the extracted ml-1m `.dat` files to the standard csv layout.

  ratings.dat
    No header row; each line is: UserID::MovieID::Rating::Timestamp
    - UserIDs range from 1 and 6040
    - MovieIDs range from 1 and 3952
    - Ratings are made on a 5-star scale (whole-star ratings only)
    - Timestamp is seconds since midnight UTC of January 1, 1970.
    - Each user has at least 20 ratings
  movies.dat
    Each line is: MovieID::Title::Genres
    - MovieIDs range from 1 and 3952
  """
  working_dir = os.path.join(temp_dir, ML_1M)
  # Both raw files are "::"-delimited and headerless; regularize each.
  conversions = (
      ("ratings.dat", RATINGS_FILE, RATING_COLUMNS),
      ("movies.dat", MOVIES_FILE, MOVIE_COLUMNS),
  )
  for raw_name, clean_name, columns in conversions:
    _transform_csv(
        input_path=os.path.join(working_dir, raw_name),
        output_path=os.path.join(temp_dir, clean_name),
        names=columns, skip_first=False, separator="::")
  tf.io.gfile.rmtree(working_dir)
def _regularize_20m_dataset(temp_dir):
  """Convert the extracted ml-20m csv files to the standard csv layout.

  ratings.csv
    Each line after the header row is: userId,movieId,rating,timestamp
    - Lines are ordered first by userId, then, within user, by movieId.
    - Ratings are made on a 5-star scale, with half-star increments
      (0.5 stars - 5.0 stars).
    - Timestamps are seconds since midnight UTC of January 1, 1970.
    - All the users had rated at least 20 movies.
  movies.csv
    Each line is: MovieID,Title,Genres
    - MovieIDs range from 1 and 3952
  """
  working_dir = os.path.join(temp_dir, ML_20M)
  # Raw files are already comma-delimited but carry their own header rows,
  # which are dropped in favor of the canonical column names.
  conversions = (
      ("ratings.csv", RATINGS_FILE, RATING_COLUMNS),
      ("movies.csv", MOVIES_FILE, MOVIE_COLUMNS),
  )
  for raw_name, clean_name, columns in conversions:
    _transform_csv(
        input_path=os.path.join(working_dir, raw_name),
        output_path=os.path.join(temp_dir, clean_name),
        names=columns, skip_first=True, separator=",")
  tf.io.gfile.rmtree(working_dir)
def download(dataset, data_dir):
  """Download and regularize one MovieLens dataset, or all of them.

  Args:
    dataset: Name of the dataset to fetch (one of DATASETS), or a falsy
      value to fetch every known dataset.
    data_dir: Directory under which per-dataset subdirectories are created.
  """
  if dataset:
    _download_and_clean(dataset, data_dir)
  else:
    # Plain loop: the calls are made for their side effects, so the
    # original's throwaway list comprehension was an anti-idiom.
    for d in DATASETS:
      _download_and_clean(d, data_dir)
def ratings_csv_to_dataframe(data_dir, dataset):
  """Load the regularized ratings csv for `dataset` as a pandas DataFrame."""
  ratings_path = os.path.join(data_dir, dataset, RATINGS_FILE)
  with tf.io.gfile.GFile(ratings_path) as ratings_file:
    return pd.read_csv(ratings_file, encoding="utf-8")
def csv_to_joint_dataframe(data_dir, dataset):
  """Join ratings with movie metadata on the item id column.

  Returns a DataFrame with one row per rating, movie columns merged in, and
  the rating column cast to float32.
  """
  ratings = ratings_csv_to_dataframe(data_dir, dataset)
  movies_path = os.path.join(data_dir, dataset, MOVIES_FILE)
  with tf.io.gfile.GFile(movies_path) as movies_file:
    movies = pd.read_csv(movies_file, encoding="utf-8")
  joined = ratings.merge(movies, on=ITEM_COLUMN)
  joined[RATING_COLUMN] = joined[RATING_COLUMN].astype(np.float32)
  return joined
def integerize_genres(dataframe):
  """Replace genre string with a binary vector.

  Args:
    dataframe: a pandas dataframe of movie data.

  Returns:
    The transformed dataframe.
  """
  def _map_fn(entry):
    # Bug fix: str.replace returns a new string — the original discarded the
    # result, so the ml-1m "Children's" genre never matched "Children".
    entry = entry.replace("Children's", "Children")  # naming difference.
    movie_genres = entry.split("|")
    output = np.zeros((len(GENRES),), dtype=np.int64)
    for i, genre in enumerate(GENRES):
      if genre in movie_genres:
        output[i] = 1
    return output

  dataframe[GENRE_COLUMN] = dataframe[GENRE_COLUMN].apply(_map_fn)
  return dataframe
def define_flags():
  """Add flags specifying data usage arguments."""
  # Case-insensitive enum restricted to the known MovieLens datasets.
  flags.DEFINE_enum(
      name="dataset", default=None, enum_values=DATASETS,
      case_sensitive=False,
      help=flags_core.help_wrap("Dataset to be trained and evaluated."))
def define_data_download_flags():
  """Add flags specifying data download and usage arguments."""
  flags.DEFINE_string(
      name="data_dir",
      default="/tmp/movielens-data/",
      help=flags_core.help_wrap("Directory to download and extract data."))
  # Also register the shared dataset-selection flag.
  define_flags()
def main(_):
  """Download and extract the data from GroupLens website."""
  user_flags = flags.FLAGS
  download(user_flags.dataset, user_flags.data_dir)
if __name__ == "__main__":
  # Register flags before absl parses argv in app.run.
  define_data_download_flags()
  # NOTE(review): this binding appears unused — `main` reads flags.FLAGS
  # directly; kept for parity with other scripts in this package.
  FLAGS = flags.FLAGS
  app.run(main)