Spaces:
Running
on
Zero
Running
on
Zero
# coding=utf-8 | |
# Copyright 2024 The HuggingFace Inc. team. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
""" | |
Utility that checks that modules like attention processors are listed in the documentation file. | |
```bash | |
python utils/check_support_list.py | |
``` | |
It has no auto-fix mode. | |
""" | |
import os | |
import re | |
# All paths are set with the intent that you run this script from the root of the repo | |
REPO_PATH = "." | |
def read_documented_classes(doc_path, autodoc_regex=r"\[\[autodoc\]\]\s([^\n]+)"):
    """
    Reads documented classes from a doc file using a regex to find lines like [[autodoc]] my.module.Class.

    Args:
        doc_path: Path of the doc file; it is joined onto `REPO_PATH`.
        autodoc_regex: Pattern with a single capturing group holding the dotted path that follows `[[autodoc]]`.

    Returns:
        A list of documented class names (just the last dotted component), in the order they appear in the file.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the platform's default encoding.
    with open(os.path.join(REPO_PATH, doc_path), "r", encoding="utf-8") as f:
        doctext = f.read()
    matches = re.findall(autodoc_regex, doctext)
    # `[^\n]+` also captures whitespace that trails the target on the same line; strip it so the name
    # compares equal to the class name extracted from the source file.
    return [match.strip().split(".")[-1] for match in matches]
def read_source_classes(src_path, class_regex, exclude_conditions=None):
    """
    Reads class names from a source file using a regex that captures class definitions.

    Args:
        src_path: Path of the source file; it is joined onto `REPO_PATH`.
        class_regex: Pattern whose single capturing group is the class name.
        exclude_conditions: Optional list of predicates (functions that take a class name and return bool).
            A class is dropped as soon as any predicate returns a truthy value for it.

    Returns:
        The list of captured class names, in file order, minus the excluded ones.
    """
    if exclude_conditions is None:
        exclude_conditions = []
    # Decode explicitly as UTF-8 so the result does not depend on the platform's default encoding.
    with open(os.path.join(REPO_PATH, src_path), "r", encoding="utf-8") as f:
        source_text = f.read()
    classes = re.findall(class_regex, source_text)
    # Filter out classes that meet any of the exclude conditions
    return [c for c in classes if not any(cond(c) for cond in exclude_conditions)]
def check_documentation(doc_path, src_path, doc_regex, src_regex, exclude_conditions=None):
    """
    Report which classes defined in `src_path` have no matching entry in `doc_path`.

    The doc file is scanned with `doc_regex` and the source file with `src_regex`; `exclude_conditions` is
    forwarded to the source scan so that intentionally undocumented classes are ignored.

    Returns an alphabetically sorted list of the undocumented class names (empty when everything is covered).
    """
    # Read the doc file first, then the source file.
    names_in_docs = set(read_documented_classes(doc_path, doc_regex))
    names_in_source = set(read_source_classes(src_path, src_regex, exclude_conditions=exclude_conditions))
    # Sorting makes the report deterministic regardless of set iteration order.
    return sorted(names_in_source.difference(names_in_docs))
if __name__ == "__main__":
    # Patterns shared by several of the checks below.
    autodoc_pattern = r"\[\[autodoc\]\]\s([^\n]+)"
    processor_pattern = r"class\s+(\w+Processor(?:\d*_?\d*))[:(]"
    nn_module_pattern = r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):"

    # One entry per documentation page to verify, keyed by the label used in the error report.
    checks = {
        "Attention Processors": {
            "doc_path": "docs/source/en/api/attnprocessor.md",
            "src_path": "src/diffusers/models/attention_processor.py",
            "doc_regex": autodoc_pattern,
            "src_regex": processor_pattern,
            "exclude_conditions": [lambda c: "LoRA" in c, lambda c: c == "Attention"],
        },
        "Image Processors": {
            "doc_path": "docs/source/en/api/image_processor.md",
            "src_path": "src/diffusers/image_processor.py",
            "doc_regex": autodoc_pattern,
            "src_regex": processor_pattern,
        },
        "Activations": {
            "doc_path": "docs/source/en/api/activations.md",
            "src_path": "src/diffusers/models/activations.py",
            "doc_regex": autodoc_pattern,
            "src_regex": nn_module_pattern,
        },
        "Normalizations": {
            "doc_path": "docs/source/en/api/normalization.md",
            "src_path": "src/diffusers/models/normalization.py",
            "doc_regex": autodoc_pattern,
            "src_regex": nn_module_pattern,
            "exclude_conditions": [
                # Exclude LayerNorm as it's an intentional exception
                lambda c: c == "LayerNorm"
            ],
        },
        "LoRA Mixins": {
            "doc_path": "docs/source/en/api/loaders/lora.md",
            "src_path": "src/diffusers/loaders/lora_pipeline.py",
            "doc_regex": autodoc_pattern,
            "src_regex": r"class\s+(\w+LoraLoaderMixin(?:\d*_?\d*))[:(]",
        },
    }

    # Run every check and remember only the categories that have gaps.
    missing_items = {}
    for category, spec in checks.items():
        undocumented = check_documentation(
            spec["doc_path"],
            spec["src_path"],
            spec["doc_regex"],
            spec["src_regex"],
            exclude_conditions=spec.get("exclude_conditions"),
        )
        if not undocumented:
            continue
        missing_items[category] = undocumented

    # Report all gaps at once in a single combined error instead of failing on the first category.
    if missing_items:
        report_lines = ["Some classes are not documented properly:\n"]
        report_lines.extend(
            f"- {category}: {', '.join(sorted(classes))}" for category, classes in missing_items.items()
        )
        raise ValueError("\n".join(report_lines))