|
"""Caching loader for the 20 newsgroups text classification dataset. |
|
|
|
|
|
The description of the dataset is available on the official website at: |
|
|
|
http://people.csail.mit.edu/jrennie/20Newsgroups/ |
|
|
|
Quoting the introduction: |
|
|
|
The 20 Newsgroups data set is a collection of approximately 20,000 |
|
newsgroup documents, partitioned (nearly) evenly across 20 different |
|
newsgroups. To the best of my knowledge, it was originally collected |
|
by Ken Lang, probably for his Newsweeder: Learning to filter netnews |
|
paper, though he does not explicitly mention this collection. The 20 |
|
newsgroups collection has become a popular data set for experiments |
|
in text applications of machine learning techniques, such as text |
|
classification and text clustering. |
|
|
|
This dataset loader will download the recommended "by date" variant of the

dataset, which features a point-in-time split between the train and

test sets. The compressed archive is around 14 MB. Once

uncompressed, the train set is 52 MB and the test set is 34 MB.
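
A typical entry point, shown for orientation (the real download is slow, so
it is skipped when run as a doctest):

    >>> from sklearn.datasets import fetch_20newsgroups
    >>> bunch = fetch_20newsgroups(subset="train")  # doctest: +SKIP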
|
""" |
|
|
|
|
|
|
|
|
|
import codecs |
|
import logging |
|
import os |
|
import pickle |
|
import re |
|
import shutil |
|
import tarfile |
|
from contextlib import suppress |
|
from numbers import Integral, Real |
|
|
|
import joblib |
|
import numpy as np |
|
import scipy.sparse as sp |
|
|
|
from .. import preprocessing |
|
from ..feature_extraction.text import CountVectorizer |
|
from ..utils import Bunch, check_random_state |
|
from ..utils._param_validation import Interval, StrOptions, validate_params |
|
from ..utils.fixes import tarfile_extractall |
|
from . import get_data_home, load_files |
|
from ._base import ( |
|
RemoteFileMetadata, |
|
_convert_data_dataframe, |
|
_fetch_remote, |
|
_pkl_filepath, |
|
load_descr, |
|
) |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
ARCHIVE = RemoteFileMetadata( |
|
filename="20news-bydate.tar.gz", |
|
url="https://ndownloader.figshare.com/files/5975967", |
|
checksum="8f1b2514ca22a5ade8fbb9cfa5727df95fa587f4c87b786e15c759fa66d95610", |
|
) |
|
|
|
CACHE_NAME = "20news-bydate.pkz" |
|
TRAIN_FOLDER = "20news-bydate-train" |
|
TEST_FOLDER = "20news-bydate-test" |
|
|
|
|
|
def _download_20newsgroups(target_dir, cache_path, n_retries, delay): |
|
"""Download the 20 newsgroups data and stored it as a zipped pickle.""" |
|
train_path = os.path.join(target_dir, TRAIN_FOLDER) |
|
test_path = os.path.join(target_dir, TEST_FOLDER) |
|
|
|
os.makedirs(target_dir, exist_ok=True) |
|
|
|
logger.info("Downloading dataset from %s (14 MB)", ARCHIVE.url) |
|
archive_path = _fetch_remote( |
|
ARCHIVE, dirname=target_dir, n_retries=n_retries, delay=delay |
|
) |
|
|
|
logger.debug("Decompressing %s", archive_path) |
|
with tarfile.open(archive_path, "r:gz") as fp: |
|
tarfile_extractall(fp, path=target_dir) |
|
|
|
with suppress(FileNotFoundError): |
|
os.remove(archive_path) |
|
|
|
|
|
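    # Bundle both splits (decoded as latin-1) into one dict and cache it on
    # disk as a zlib-compressed pickle so later calls can skip the download.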
cache = dict( |
|
train=load_files(train_path, encoding="latin1"), |
|
test=load_files(test_path, encoding="latin1"), |
|
) |
|
compressed_content = codecs.encode(pickle.dumps(cache), "zlib_codec") |
|
with open(cache_path, "wb") as f: |
|
f.write(compressed_content) |
|
|
|
shutil.rmtree(target_dir) |
|
return cache |
|
|
|
|
|
def strip_newsgroup_header(text): |
|
""" |
|
    Given text in "news" format, strip the headers by removing everything
|
before the first blank line. |
|
|
|
Parameters |
|
---------- |
|
text : str |
|
        The text from which to remove the header.
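
    Examples
    --------
    A minimal sketch of the behavior on a typical post:

    >>> post = '''From: user@example.com
    ... Subject: hello
    ...
    ... Body of the post.'''
    >>> strip_newsgroup_header(post)
    'Body of the post.'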
|
""" |
|
_before, _blankline, after = text.partition("\n\n") |
|
return after |
|
|
|
|
|
_QUOTE_RE = re.compile( |
|
r"(writes in|writes:|wrote:|says:|said:" r"|^In article|^Quoted from|^\||^>)" |
|
) |
|
|
|
|
|
def strip_newsgroup_quoting(text): |
|
""" |
|
Given text in "news" format, strip lines beginning with the quote |
|
characters > or |, plus lines that often introduce a quoted section |
|
    (for example, because they contain the string 'writes:').
|
|
|
Parameters |
|
---------- |
|
text : str |
|
        The text from which to remove quoted lines.
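
    Examples
    --------
    A minimal sketch of the behavior (quoted and attribution lines dropped):

    >>> post = '''Someone writes:
    ... > quoted text
    ... the reply'''
    >>> strip_newsgroup_quoting(post)
    'the reply'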
|
""" |
|
good_lines = [line for line in text.split("\n") if not _QUOTE_RE.search(line)] |
|
return "\n".join(good_lines) |
|
|
|
|
|
def strip_newsgroup_footer(text): |
|
""" |
|
Given text in "news" format, attempt to remove a signature block. |
|
|
|
As a rough heuristic, we assume that signatures are set apart by either |
|
a blank line or a line made of hyphens, and that it is the last such line |
|
in the file (disregarding blank lines at the end). |
|
|
|
Parameters |
|
---------- |
|
text : str |
|
The text from which to remove the signature block. |
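
    Examples
    --------
    A minimal sketch of the heuristic (the trailing signature is dropped):

    >>> post = '''Main body of the post.
    ... --
    ... Jane Doe'''
    >>> strip_newsgroup_footer(post)
    'Main body of the post.'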
|
""" |
|
lines = text.strip().split("\n") |
|
for line_num in range(len(lines) - 1, -1, -1): |
|
line = lines[line_num] |
|
if line.strip().strip("-") == "": |
|
break |
|
|
|
if line_num > 0: |
|
return "\n".join(lines[:line_num]) |
|
else: |
|
return text |
|
|
|
|
|
@validate_params( |
|
{ |
|
"data_home": [str, os.PathLike, None], |
|
"subset": [StrOptions({"train", "test", "all"})], |
|
"categories": ["array-like", None], |
|
"shuffle": ["boolean"], |
|
"random_state": ["random_state"], |
|
"remove": [tuple], |
|
"download_if_missing": ["boolean"], |
|
"return_X_y": ["boolean"], |
|
"n_retries": [Interval(Integral, 1, None, closed="left")], |
|
"delay": [Interval(Real, 0.0, None, closed="neither")], |
|
}, |
|
prefer_skip_nested_validation=True, |
|
) |
|
def fetch_20newsgroups( |
|
*, |
|
data_home=None, |
|
subset="train", |
|
categories=None, |
|
shuffle=True, |
|
random_state=42, |
|
remove=(), |
|
download_if_missing=True, |
|
return_X_y=False, |
|
n_retries=3, |
|
delay=1.0, |
|
): |
|
"""Load the filenames and data from the 20 newsgroups dataset \ |
|
(classification). |
|
|
|
Download it if necessary. |
|
|
|
================= ========== |
|
Classes 20 |
|
Samples total 18846 |
|
Dimensionality 1 |
|
Features text |
|
================= ========== |
|
|
|
Read more in the :ref:`User Guide <20newsgroups_dataset>`. |
|
|
|
Parameters |
|
---------- |
|
data_home : str or path-like, default=None |
|
Specify a download and cache folder for the datasets. If None, |
|
all scikit-learn data is stored in '~/scikit_learn_data' subfolders. |
|
|
|
subset : {'train', 'test', 'all'}, default='train' |
|
Select the dataset to load: 'train' for the training set, 'test' |
|
for the test set, 'all' for both, with shuffled ordering. |
|
|
|
categories : array-like, dtype=str, default=None |
|
If None (default), load all the categories. |
|
If not None, list of category names to load (other categories |
|
ignored). |
|
|
|
shuffle : bool, default=True |
|
Whether or not to shuffle the data: might be important for models that |
|
make the assumption that the samples are independent and identically |
|
distributed (i.i.d.), such as stochastic gradient descent. |
|
|
|
random_state : int, RandomState instance or None, default=42 |
|
Determines random number generation for dataset shuffling. Pass an int |
|
for reproducible output across multiple function calls. |
|
See :term:`Glossary <random_state>`. |
|
|
|
remove : tuple, default=() |
|
        May contain any subset of ('headers', 'footers', 'quotes'). Each of

        these is a kind of text that will be detected and removed from the

        newsgroup posts, preventing classifiers from overfitting on

        metadata.
|
|
|
'headers' removes newsgroup headers, 'footers' removes blocks at the |
|
ends of posts that look like signatures, and 'quotes' removes lines |
|
that appear to be quoting another post. |
|
|
|
'headers' follows an exact standard; the other filters are not always |
|
correct. |
|
|
|
download_if_missing : bool, default=True |
|
If False, raise an OSError if the data is not locally available |
|
instead of trying to download the data from the source site. |
|
|
|
return_X_y : bool, default=False |
|
If True, returns `(data.data, data.target)` instead of a Bunch |
|
object. |
|
|
|
.. versionadded:: 0.22 |
|
|
|
n_retries : int, default=3 |
|
Number of retries when HTTP errors are encountered. |
|
|
|
.. versionadded:: 1.5 |
|
|
|
delay : float, default=1.0 |
|
Number of seconds between retries. |
|
|
|
.. versionadded:: 1.5 |
|
|
|
Returns |
|
------- |
|
bunch : :class:`~sklearn.utils.Bunch` |
|
Dictionary-like object, with the following attributes. |
|
|
|
data : list of shape (n_samples,) |
|
            The raw text documents to learn from.
|
target: ndarray of shape (n_samples,) |
|
The target labels. |
|
filenames: list of shape (n_samples,) |
|
            The paths to the files containing the raw text data.
|
DESCR: str |
|
The full description of the dataset. |
|
target_names: list of shape (n_classes,) |
|
The names of target classes. |
|
|
|
(data, target) : tuple if `return_X_y=True` |
|
        A tuple of two elements. The first is a list of length n_samples

        containing the raw text documents. The second is an ndarray of shape

        (n_samples,) containing the integer target labels.
|
|
|
.. versionadded:: 0.22 |
|
|
|
Examples |
|
-------- |
|
>>> from sklearn.datasets import fetch_20newsgroups |
|
>>> cats = ['alt.atheism', 'sci.space'] |
|
>>> newsgroups_train = fetch_20newsgroups(subset='train', categories=cats) |
|
>>> list(newsgroups_train.target_names) |
|
['alt.atheism', 'sci.space'] |
|
>>> newsgroups_train.filenames.shape |
|
(1073,) |
|
>>> newsgroups_train.target.shape |
|
(1073,) |
|
>>> newsgroups_train.target[:10] |
|
array([0, 1, 1, 1, 0, 1, 1, 0, 0, 0]) |
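
    The ``remove`` option strips metadata without changing the number of
    samples, which this quick check illustrates:

    >>> stripped = fetch_20newsgroups(subset='train', categories=cats,
    ...                               remove=('headers', 'footers', 'quotes'))
    >>> len(stripped.data) == len(newsgroups_train.data)
    True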
|
""" |
|
|
|
data_home = get_data_home(data_home=data_home) |
|
cache_path = _pkl_filepath(data_home, CACHE_NAME) |
|
twenty_home = os.path.join(data_home, "20news_home") |
|
cache = None |
|
if os.path.exists(cache_path): |
|
try: |
|
with open(cache_path, "rb") as f: |
|
compressed_content = f.read() |
|
uncompressed_content = codecs.decode(compressed_content, "zlib_codec") |
|
cache = pickle.loads(uncompressed_content) |
|
except Exception as e: |
|
print(80 * "_") |
|
print("Cache loading failed") |
|
print(80 * "_") |
|
print(e) |
|
|
|
if cache is None: |
|
if download_if_missing: |
|
logger.info("Downloading 20news dataset. This may take a few minutes.") |
|
cache = _download_20newsgroups( |
|
target_dir=twenty_home, |
|
cache_path=cache_path, |
|
n_retries=n_retries, |
|
delay=delay, |
|
) |
|
else: |
|
raise OSError("20Newsgroups dataset not found") |
|
|
|
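    # Select the requested split; for "all", concatenate the train and test
    # documents, targets and filenames.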
if subset in ("train", "test"): |
|
data = cache[subset] |
|
elif subset == "all": |
|
data_lst = list() |
|
target = list() |
|
filenames = list() |
|
for subset in ("train", "test"): |
|
data = cache[subset] |
|
data_lst.extend(data.data) |
|
target.extend(data.target) |
|
filenames.extend(data.filenames) |
|
|
|
data.data = data_lst |
|
data.target = np.array(target) |
|
data.filenames = np.array(filenames) |
|
|
|
fdescr = load_descr("twenty_newsgroups.rst") |
|
|
|
data.DESCR = fdescr |
|
|
|
if "headers" in remove: |
|
data.data = [strip_newsgroup_header(text) for text in data.data] |
|
if "footers" in remove: |
|
data.data = [strip_newsgroup_footer(text) for text in data.data] |
|
if "quotes" in remove: |
|
data.data = [strip_newsgroup_quoting(text) for text in data.data] |
|
|
|
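    # Keep only the requested categories and remap the integer targets onto
    # 0..len(categories)-1, following the sorted original label order.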
if categories is not None: |
|
labels = [(data.target_names.index(cat), cat) for cat in categories] |
|
|
|
labels.sort() |
|
labels, categories = zip(*labels) |
|
mask = np.isin(data.target, labels) |
|
data.filenames = data.filenames[mask] |
|
data.target = data.target[mask] |
|
|
|
data.target = np.searchsorted(labels, data.target) |
|
data.target_names = list(categories) |
|
|
|
data_lst = np.array(data.data, dtype=object) |
|
data_lst = data_lst[mask] |
|
data.data = data_lst.tolist() |
|
|
|
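    # Apply one shared permutation so documents, targets and filenames stay
    # aligned after shuffling.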
if shuffle: |
|
random_state = check_random_state(random_state) |
|
indices = np.arange(data.target.shape[0]) |
|
random_state.shuffle(indices) |
|
data.filenames = data.filenames[indices] |
|
data.target = data.target[indices] |
|
|
|
data_lst = np.array(data.data, dtype=object) |
|
data_lst = data_lst[indices] |
|
data.data = data_lst.tolist() |
|
|
|
if return_X_y: |
|
return data.data, data.target |
|
|
|
return data |
|
|
|
|
|
@validate_params( |
|
{ |
|
"subset": [StrOptions({"train", "test", "all"})], |
|
"remove": [tuple], |
|
"data_home": [str, os.PathLike, None], |
|
"download_if_missing": ["boolean"], |
|
"return_X_y": ["boolean"], |
|
"normalize": ["boolean"], |
|
"as_frame": ["boolean"], |
|
"n_retries": [Interval(Integral, 1, None, closed="left")], |
|
"delay": [Interval(Real, 0.0, None, closed="neither")], |
|
}, |
|
prefer_skip_nested_validation=True, |
|
) |
|
def fetch_20newsgroups_vectorized( |
|
*, |
|
subset="train", |
|
remove=(), |
|
data_home=None, |
|
download_if_missing=True, |
|
return_X_y=False, |
|
normalize=True, |
|
as_frame=False, |
|
n_retries=3, |
|
delay=1.0, |
|
): |
|
"""Load and vectorize the 20 newsgroups dataset (classification). |
|
|
|
Download it if necessary. |
|
|
|
This is a convenience function; the transformation is done using the |
|
default settings for |
|
:class:`~sklearn.feature_extraction.text.CountVectorizer`. For more |
|
advanced usage (stopword filtering, n-gram extraction, etc.), combine |
|
    :func:`fetch_20newsgroups` with a custom
|
:class:`~sklearn.feature_extraction.text.CountVectorizer`, |
|
:class:`~sklearn.feature_extraction.text.HashingVectorizer`, |
|
:class:`~sklearn.feature_extraction.text.TfidfTransformer` or |
|
:class:`~sklearn.feature_extraction.text.TfidfVectorizer`. |
|
|
|
The resulting counts are normalized using |
|
:func:`sklearn.preprocessing.normalize` unless normalize is set to False. |
|
|
|
================= ========== |
|
Classes 20 |
|
Samples total 18846 |
|
Dimensionality 130107 |
|
Features real |
|
================= ========== |
|
|
|
Read more in the :ref:`User Guide <20newsgroups_dataset>`. |
|
|
|
Parameters |
|
---------- |
|
subset : {'train', 'test', 'all'}, default='train' |
|
Select the dataset to load: 'train' for the training set, 'test' |
|
for the test set, 'all' for both, with shuffled ordering. |
|
|
|
remove : tuple, default=() |
|
        May contain any subset of ('headers', 'footers', 'quotes'). Each of

        these is a kind of text that will be detected and removed from the

        newsgroup posts, preventing classifiers from overfitting on

        metadata.
|
|
|
'headers' removes newsgroup headers, 'footers' removes blocks at the |
|
ends of posts that look like signatures, and 'quotes' removes lines |
|
that appear to be quoting another post. |
|
|
|
data_home : str or path-like, default=None |
|
        Specify a download and cache folder for the datasets. If None,
|
all scikit-learn data is stored in '~/scikit_learn_data' subfolders. |
|
|
|
download_if_missing : bool, default=True |
|
If False, raise an OSError if the data is not locally available |
|
instead of trying to download the data from the source site. |
|
|
|
return_X_y : bool, default=False |
|
If True, returns ``(data.data, data.target)`` instead of a Bunch |
|
object. |
|
|
|
.. versionadded:: 0.20 |
|
|
|
normalize : bool, default=True |
|
If True, normalizes each document's feature vector to unit norm using |
|
:func:`sklearn.preprocessing.normalize`. |
|
|
|
.. versionadded:: 0.22 |
|
|
|
as_frame : bool, default=False |
|
If True, the data is a pandas DataFrame including columns with |
|
appropriate dtypes (numeric, string, or categorical). The target is |
|
a pandas DataFrame or Series depending on the number of |
|
        target columns.
|
|
|
.. versionadded:: 0.24 |
|
|
|
n_retries : int, default=3 |
|
Number of retries when HTTP errors are encountered. |
|
|
|
.. versionadded:: 1.5 |
|
|
|
delay : float, default=1.0 |
|
Number of seconds between retries. |
|
|
|
.. versionadded:: 1.5 |
|
|
|
Returns |
|
------- |
|
bunch : :class:`~sklearn.utils.Bunch` |
|
Dictionary-like object, with the following attributes. |
|
|
|
data: {sparse matrix, dataframe} of shape (n_samples, n_features) |
|
The input data matrix. If ``as_frame`` is `True`, ``data`` is |
|
a pandas DataFrame with sparse columns. |
|
target: {ndarray, series} of shape (n_samples,) |
|
The target labels. If ``as_frame`` is `True`, ``target`` is a |
|
pandas Series. |
|
target_names: list of shape (n_classes,) |
|
The names of target classes. |
|
DESCR: str |
|
The full description of the dataset. |
|
frame: dataframe of shape (n_samples, n_features + 1) |
|
Only present when `as_frame=True`. Pandas DataFrame with ``data`` |
|
and ``target``. |
|
|
|
.. versionadded:: 0.24 |
|
|
|
(data, target) : tuple if ``return_X_y`` is True |
|
        `data` and `target` will be in the format defined in the `Bunch`
|
description above. |
|
|
|
.. versionadded:: 0.20 |
|
|
|
Examples |
|
-------- |
|
>>> from sklearn.datasets import fetch_20newsgroups_vectorized |
|
>>> newsgroups_vectorized = fetch_20newsgroups_vectorized(subset='test') |
|
>>> newsgroups_vectorized.data.shape |
|
(7532, 130107) |
|
>>> newsgroups_vectorized.target.shape |
|
(7532,) |
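
    With ``return_X_y=True`` the same split is returned as a plain
    ``(data, target)`` tuple:

    >>> X, y = fetch_20newsgroups_vectorized(subset='test', return_X_y=True)
    >>> X.shape
    (7532, 130107)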
|
""" |
|
data_home = get_data_home(data_home=data_home) |
|
filebase = "20newsgroup_vectorized" |
|
if remove: |
|
filebase += "remove-" + "-".join(remove) |
|
target_file = _pkl_filepath(data_home, filebase + ".pkl") |
|
|
|
|
|
data_train = fetch_20newsgroups( |
|
data_home=data_home, |
|
subset="train", |
|
categories=None, |
|
shuffle=True, |
|
random_state=12, |
|
remove=remove, |
|
download_if_missing=download_if_missing, |
|
n_retries=n_retries, |
|
delay=delay, |
|
) |
|
|
|
data_test = fetch_20newsgroups( |
|
data_home=data_home, |
|
subset="test", |
|
categories=None, |
|
shuffle=True, |
|
random_state=12, |
|
remove=remove, |
|
download_if_missing=download_if_missing, |
|
n_retries=n_retries, |
|
delay=delay, |
|
) |
|
|
|
if os.path.exists(target_file): |
|
try: |
|
X_train, X_test, feature_names = joblib.load(target_file) |
|
except ValueError as e: |
|
raise ValueError( |
|
f"The cached dataset located in {target_file} was fetched " |
|
"with an older scikit-learn version and it is not compatible " |
|
"with the scikit-learn version imported. You need to " |
|
f"manually delete the file: {target_file}." |
|
) from e |
|
else: |
|
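        # Build raw term counts; dtype=np.int16 keeps the cached matrices
        # small, and CSR layout makes the later row operations efficient.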
vectorizer = CountVectorizer(dtype=np.int16) |
|
X_train = vectorizer.fit_transform(data_train.data).tocsr() |
|
X_test = vectorizer.transform(data_test.data).tocsr() |
|
feature_names = vectorizer.get_feature_names_out() |
|
|
|
joblib.dump((X_train, X_test, feature_names), target_file, compress=9) |
|
|
|
|
|
|
|
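    # Upcast the int16 counts to float64 so each row can be L2-normalized
    # in place.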
if normalize: |
|
X_train = X_train.astype(np.float64) |
|
X_test = X_test.astype(np.float64) |
|
preprocessing.normalize(X_train, copy=False) |
|
preprocessing.normalize(X_test, copy=False) |
|
|
|
target_names = data_train.target_names |
|
|
|
if subset == "train": |
|
data = X_train |
|
target = data_train.target |
|
elif subset == "test": |
|
data = X_test |
|
target = data_test.target |
|
elif subset == "all": |
|
data = sp.vstack((X_train, X_test)).tocsr() |
|
target = np.concatenate((data_train.target, data_test.target)) |
|
|
|
fdescr = load_descr("twenty_newsgroups.rst") |
|
|
|
frame = None |
|
target_name = ["category_class"] |
|
|
|
if as_frame: |
|
frame, data, target = _convert_data_dataframe( |
|
"fetch_20newsgroups_vectorized", |
|
data, |
|
target, |
|
feature_names, |
|
target_names=target_name, |
|
sparse_data=True, |
|
) |
|
|
|
if return_X_y: |
|
return data, target |
|
|
|
return Bunch( |
|
data=data, |
|
target=target, |
|
frame=frame, |
|
target_names=target_names, |
|
feature_names=feature_names, |
|
DESCR=fdescr, |
|
) |
|
|