Files
diffusers/utils/check_forward_call_docstrings.py
Sayak Paul f502538134 [WIP] chore: add utilities to check if call/forward methods are documented. (#13758)
* chore: add utilities to check if call/forward methods are documented.

* Fix missing forward/__call__ docstring entries (#13769)

add missing

* style.

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2026-05-21 08:00:22 +05:30

274 lines
8.5 KiB
Python

# coding=utf-8
# Copyright 2026 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Check that arguments of ``forward()`` (for models) and ``__call__()`` (for
pipelines) match the method's docstring exactly:
* every signature argument has an entry in the ``Args:`` /
``Arguments:`` / ``Parameters:`` section, and
* every documented argument still exists in the signature
(stale entries from removed/renamed args are flagged).
A "main" class is detected via its base classes — models inherit from
``ModelMixin`` and pipelines inherit from ``DiffusionPipeline``. Only methods
defined directly on the class are checked; inherited methods are checked when
the parent class is visited.
Run from the repository root:
python utils/check_forward_call_docstrings.py
Optionally restrict to specific files:
python utils/check_forward_call_docstrings.py --paths src/diffusers/models/transformers/transformer_flux.py
"""
from __future__ import annotations
import argparse
import ast
import re
import sys
from pathlib import Path
REPO_ROOT = Path(__file__).resolve().parents[1]
MODELS_DIR = REPO_ROOT / "src" / "diffusers" / "models"
PIPELINES_DIR = REPO_ROOT / "src" / "diffusers" / "pipelines"
MODEL_BASE = "ModelMixin"
PIPELINE_BASE = "DiffusionPipeline"
SECTION_HEADERS = {
"Args:",
"Arguments:",
"Parameters:",
"Returns:",
"Return:",
"Yields:",
"Raises:",
"Examples:",
"Example:",
"Note:",
"Notes:",
"References:",
"See Also:",
}
# `name (...)` or `name:` at the start of a (stripped) line.
_ARG_HEADER_RE = re.compile(r"^([A-Za-z_]\w*)\s*[(:]")
# Pairs of (class_name, method_name) whose missing-arg errors should be
# suppressed. Use sparingly — prefer fixing the docstring.
IGNORE: set[tuple[str, str]] = set()
def _base_class_names(class_def: ast.ClassDef) -> set[str]:
"""Return the textual names of base classes (best-effort)."""
names: set[str] = set()
for base in class_def.bases:
if isinstance(base, ast.Name):
names.add(base.id)
elif isinstance(base, ast.Attribute):
names.add(base.attr)
return names
def _find_method(class_def: ast.ClassDef, method_name: str) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
for node in class_def.body:
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == method_name:
return node
return None
def _signature_arg_names(func: ast.FunctionDef | ast.AsyncFunctionDef) -> list[str]:
args = func.args
collected: list[str] = []
for a in (*args.posonlyargs, *args.args, *args.kwonlyargs):
if a.arg == "self" or a.arg == "cls":
continue
collected.append(a.arg)
return collected
def _extract_documented_args(docstring: str | None) -> set[str]:
"""Extract argument names listed in an Args/Arguments/Parameters section.
Assumes the docstring has been cleaned (``inspect.cleandoc`` / ``ast.get_docstring``).
The section ends at the next blank-line-followed-by-section-header or at the
end of the docstring.
"""
if not docstring:
return set()
lines = docstring.splitlines()
# Locate the Args/Arguments/Parameters header.
start = None
header_indent = 0
for i, line in enumerate(lines):
stripped = line.strip()
if stripped in {"Args:", "Arguments:", "Parameters:"}:
start = i + 1
header_indent = len(line) - len(line.lstrip())
break
if start is None:
return set()
# First non-empty line after the header sets the per-entry indent level.
entry_indent: int | None = None
documented: set[str] = set()
for line in lines[start:]:
stripped = line.strip()
if not stripped:
continue
indent = len(line) - len(line.lstrip())
# A new section at the same (or shallower) indent ends the args block.
if indent <= header_indent and stripped in SECTION_HEADERS:
break
if entry_indent is None:
entry_indent = indent
# Only lines at the entry indent are candidate arg headers; deeper
# indents are descriptions/continuations.
if indent != entry_indent:
continue
match = _ARG_HEADER_RE.match(stripped)
if match:
documented.add(match.group(1))
return documented
def check_file(path: Path, kind: str) -> list[str]:
"""Return a list of human-readable error strings for ``path``."""
method_name = "forward" if kind == "model" else "__call__"
base_class = MODEL_BASE if kind == "model" else PIPELINE_BASE
try:
tree = ast.parse(path.read_text(encoding="utf-8"))
except (SyntaxError, UnicodeDecodeError):
return []
errors: list[str] = []
rel = path.relative_to(REPO_ROOT)
for node in ast.walk(tree):
if not isinstance(node, ast.ClassDef):
continue
if base_class not in _base_class_names(node):
continue
if (node.name, method_name) in IGNORE:
continue
method = _find_method(node, method_name)
if method is None:
continue
sig_args = _signature_arg_names(method)
if not sig_args:
continue
sig_set = set(sig_args)
documented = _extract_documented_args(ast.get_docstring(method))
missing = [a for a in sig_args if a not in documented]
stale = sorted(documented - sig_set)
if missing:
errors.append(
f"{rel}:{method.lineno}: {node.name}.{method_name} is missing "
f"docstring entries for: {', '.join(missing)}"
)
if stale:
errors.append(
f"{rel}:{method.lineno}: {node.name}.{method_name} documents "
f"argument(s) not in the signature: {', '.join(stale)}"
)
return errors
def _kind_for_path(path: Path) -> str | None:
parts = path.resolve().parts
if "pipelines" in parts:
return "pipeline"
if "models" in parts:
return "model"
return None
def main() -> int:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--paths",
nargs="+",
help="Specific files to check (defaults to all of src/diffusers/{models,pipelines}).",
)
parser.add_argument(
"--limit",
type=int,
default=None,
help=(
"Debug helper: when --paths is not given, only check the first N files "
"(in sorted order) from each of models/ and pipelines/."
),
)
args = parser.parse_args()
targets: list[tuple[Path, str]] = []
if args.paths:
for raw in args.paths:
p = Path(raw).resolve()
kind = _kind_for_path(p)
if kind is None:
print(f"Skipping {raw}: not under models/ or pipelines/.", file=sys.stderr)
continue
targets.append((p, kind))
else:
model_files = sorted(MODELS_DIR.rglob("*.py"))
pipeline_files = sorted(PIPELINES_DIR.rglob("*.py"))
if args.limit is not None:
if args.limit < 0:
parser.error("--limit must be non-negative")
model_files = model_files[: args.limit]
pipeline_files = pipeline_files[: args.limit]
print(
f"--limit {args.limit}: checking {len(model_files)} model file(s) "
f"and {len(pipeline_files)} pipeline file(s).",
file=sys.stderr,
)
for p in model_files:
targets.append((p, "model"))
for p in pipeline_files:
targets.append((p, "pipeline"))
all_errors: list[str] = []
for path, kind in targets:
all_errors.extend(check_file(path, kind))
if all_errors:
print("\n".join(all_errors))
print(
f"\nFound {len(all_errors)} docstring/signature mismatch(es).",
file=sys.stderr,
)
return 1
print("All forward/__call__ arguments are documented.")
return 0
if __name__ == "__main__":
sys.exit(main())