|
import re |
|
from inspect import signature |
|
from typing import Optional |
|
|
|
import pytest |
|
|
|
|
|
from sklearn.experimental import ( |
|
enable_halving_search_cv, |
|
enable_iterative_imputer, |
|
) |
|
from sklearn.utils.discovery import all_displays, all_estimators, all_functions |
|
|
|
# numpydoc is an optional dependency: importorskip skips this whole test
# module when "numpydoc.validate" cannot be imported, and otherwise returns
# the imported module, which is used below to run the docstring validation.
numpydoc_validation = pytest.importorskip("numpydoc.validate")
|
|
|
|
|
def get_all_methods():
    """Yield ``(Klass, method)`` pairs for every public estimator and display.

    For each class returned by ``all_estimators()`` and ``all_displays()``
    whose name does not start with an underscore, one pair is yielded per
    public (non-underscore) attribute that is callable or a ``property``, plus
    one pair with ``method=None`` standing for the class itself. Within a
    class, pairs are yielded sorted by ``str(method)``, so ``None`` sorts as
    ``"None"``.

    Yields
    ------
    Klass : type
        The estimator or display class.
    method : str or None
        Name of a public method/property of ``Klass``, or None for the class.
    """
    estimators = all_estimators()
    displays = all_displays()
    for class_name, Klass in estimators + displays:
        if class_name.startswith("_"):
            continue

        # Use a distinct loop variable: the original reused ``name`` here,
        # shadowing the outer class-name loop variable.
        methods = []
        for attr_name in dir(Klass):
            if attr_name.startswith("_"):
                continue
            attr = getattr(Klass, attr_name)
            if callable(attr) or isinstance(attr, property):
                methods.append(attr_name)
        # ``None`` stands for the class docstring itself.
        methods.append(None)

        for method in sorted(methods, key=str):
            yield Klass, method
|
|
|
|
|
def get_all_functions_names():
    """Yield the dotted import path of every function from ``all_functions()``.

    Functions whose ``__module__`` contains ``"utils.fixes"`` are skipped.

    Yields
    ------
    str
        ``"<module>.<function name>"`` for each remaining function.
    """
    for _, func in all_functions():
        module_path = func.__module__
        if "utils.fixes" in module_path:
            continue
        yield f"{module_path}.{func.__name__}"
|
|
|
|
|
def filter_errors(errors, method, Klass=None):
    """Ignore some errors based on the method type.

    These rules are specific for scikit-learn.

    Parameters
    ----------
    errors : iterable of (code, message)
        Errors as reported by numpydoc validation.
    method : str or None
        Name of the validated method, an arbitrary non-None marker such as
        ``"function"``, or None when validating a class docstring.
    Klass : type or None, default=None
        Class owning ``method``; needed to detect properties.

    Yields
    ------
    code, message
        Every error that is not ignored by the rules below.
    """
    always_ignored = ("RT02", "GL01", "GL02")
    ignored_when_method = ("EX01", "SA01", "ES01")
    ignored_for_property = ("PR02", "GL08")

    def _is_property(attr_name):
        # Only looked up when actually needed, so the attribute is not
        # accessed for unrelated error codes.
        return isinstance(getattr(Klass, attr_name), property)

    for code, message in errors:
        if code in always_ignored:
            continue
        has_method = method is not None
        if has_method and code in ignored_when_method:
            continue
        if (
            has_method
            and Klass is not None
            and code in ignored_for_property
            and _is_property(method)
        ):
            continue
        yield code, message
|
|
|
|
|
def repr_errors(res, Klass=None, method: Optional[str] = None) -> str:
    """Pretty print original docstring and the obtained errors.

    Parameters
    ----------
    res : dict
        Result of ``numpydoc.validate.validate``; the keys ``"file"``,
        ``"docstring"`` and ``"errors"`` are read.
    Klass : {Estimator, Display, None}, default=None
        Estimator or display class, or None.
    method : str, default=None
        If ``Klass`` is not None, either the method name or None (in which
        case ``"__init__"`` is used). If ``Klass`` is None, the name shown
        for the validated object.

    Returns
    -------
    str
        String representation of the error.

    Raises
    ------
    ValueError
        If both ``Klass`` and ``method`` are None.
    """
    if method is None:
        # Check ``Klass`` first: every object (including None) has an
        # ``__init__`` attribute, so testing ``hasattr(Klass, "__init__")``
        # first made this error unreachable.
        if Klass is None:
            raise ValueError("At least one of Klass, method should be provided")
        method = "__init__"

    if Klass is not None:
        obj = getattr(Klass, method)
        try:
            obj_signature = str(signature(obj))
        except (TypeError, ValueError):
            # inspect.signature raises TypeError for non-callables such as
            # properties, and ValueError when no signature can be provided.
            obj_signature = (
                "\nParsing of the method signature failed, "
                "possibly because this is a property."
            )
        obj_name = Klass.__name__ + "." + method
    else:
        obj_signature = ""
        obj_name = method

    errors_text = "\n".join(
        f" - {code}: {message}" for code, message in res["errors"]
    )
    sections = [
        str(res["file"]),
        obj_name + obj_signature,
        res["docstring"],
        "# Errors",
        errors_text,
    ]
    return "\n\n" + "\n\n".join(sections)
|
|
|
|
|
@pytest.mark.parametrize("function_name", get_all_functions_names())
def test_function_docstring(function_name, request):
    """Check function docstrings using numpydoc.

    ``method="function"`` is passed to ``filter_errors`` so that the
    exemptions reserved for a non-None ``method`` also apply to functions.
    """
    validation = numpydoc_validation.validate(function_name)
    kept = list(filter_errors(validation["errors"], method="function"))
    validation["errors"] = kept

    if not kept:
        return
    raise ValueError(
        repr_errors(validation, method=f"Tested function: {function_name}")
    )
|
|
|
|
|
@pytest.mark.parametrize("Klass, method", get_all_methods())
def test_docstring(Klass, method, request):
    """Check the numpydoc docstring of a class, or of one of its methods.

    The object is validated through its dotted path
    ``<module>.<ClassName>[.<method>]``.
    """
    dotted_path = f"{Klass.__module__}.{Klass.__name__}"
    if method is not None:
        dotted_path = f"{dotted_path}.{method}"

    outcome = numpydoc_validation.validate(dotted_path)
    outcome["errors"] = list(filter_errors(outcome["errors"], method, Klass=Klass))

    if not outcome["errors"]:
        return
    raise ValueError(repr_errors(outcome, Klass, method))
|
|
|
|
|
if __name__ == "__main__":
    import argparse
    import sys

    # Command-line entry point: validate a single object's docstring given
    # its dotted import path, print the result, and exit with status 1 when
    # errors remain.
    parser = argparse.ArgumentParser(description="Validate docstring with numpydoc.")
    parser.add_argument("import_path", help="Import path to validate")
    args = parser.parse_args()
    target = args.import_path

    res = numpydoc_validation.validate(target)

    # If the second-to-last dotted component starts with an uppercase letter
    # (a CamelCase class name), the last component is treated as a method.
    # NOTE(review): Klass is not passed to filter_errors here, so the
    # property-based PR02/GL08 exemption used by test_docstring never applies.
    # Also, a plain function path yields method=None, whereas
    # test_function_docstring uses method="function"; the CLI may therefore
    # report errors the test suite ignores.
    path_parts = target.split(".")
    second_to_last_is_class = len(path_parts) >= 2 and re.match(
        r"(?:[A-Z][a-z]*)+", path_parts[-2]
    )
    method = path_parts[-1] if second_to_last_is_class else None

    res["errors"] = list(filter_errors(res["errors"], method))

    if not res["errors"]:
        print("All docstring checks passed for {}!".format(target))
    else:
        msg = repr_errors(res, method=target)
        print(msg)
        sys.exit(1)
|
|