Files
diffusers/utils/check_forward_call_docstrings.py
Sayak Paul e87b2a7ad8 [docs] Follow ups for consistent forward docstrings (#13779)
* feat: allow docstring checker to fix unused args in docstrings.

* feat: check for returns, too.

* [docs] add missing Returns sections to forward/__call__ docstrings (#13830)

Adds a Returns: section to the 43 model forward() and pipeline __call__()
methods flagged by utils/check_forward_call_docstrings.py, which requires a
Returns: section whenever the method has a non-None return annotation.

Descriptions reflect the actual return statements (Output dataclass when
return_dict=True, plain tuple otherwise; bare tensors / lists where
applicable) and reuse each file's existing doc-builder link form. Also
reformats a malformed single-line Returns: in pipeline_aura_flow.py that the
check could not detect.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>

* Apply style fixes

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-05-29 09:07:10 +05:30

471 lines
16 KiB
Python

# coding=utf-8
# Copyright 2026 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Check that arguments of ``forward()`` (for models) and ``__call__()`` (for
pipelines) match the method's docstring exactly:
* every signature argument has an entry in the ``Args:`` /
``Arguments:`` / ``Parameters:`` section,
* every documented argument still exists in the signature
(stale entries from removed/renamed args are flagged), and
* when the method has a non-``None`` return annotation, the docstring has
a ``Returns:`` / ``Return:`` / ``Yields:`` section.
A "main" class is detected via its base classes — models inherit from
``ModelMixin`` and pipelines inherit from ``DiffusionPipeline``. Only methods
defined directly on the class are checked; inherited methods are checked when
the parent class is visited.
Run from the repository root:
python utils/check_forward_call_docstrings.py
Optionally restrict to specific files:
python utils/check_forward_call_docstrings.py --paths src/diffusers/models/transformers/transformer_flux.py
Auto-fix stale (documented-but-removed) entries — missing entries are never
auto-added (no placeholders), only stale ones are removed:
python utils/check_forward_call_docstrings.py --fix
"""
from __future__ import annotations
import argparse
import ast
import re
import sys
from pathlib import Path
# Repository root: the parent of the directory that contains this script.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Directories scanned when no explicit ``--paths`` are given.
MODELS_DIR = REPO_ROOT / "src" / "diffusers" / "models"
PIPELINES_DIR = REPO_ROOT / "src" / "diffusers" / "pipelines"

# Base-class names that mark a class as a "main" model / pipeline.
MODEL_BASE = "ModelMixin"
PIPELINE_BASE = "DiffusionPipeline"

# Docstring lines that start a new section and therefore terminate the
# Args/Arguments/Parameters block. ``Yield:`` is listed alongside ``Yields:``
# because the return-section check accepts both spellings; without it, entries
# under a ``Yield:`` header would be parsed as (stale) arguments.
SECTION_HEADERS = {
    "Args:",
    "Arguments:",
    "Parameters:",
    "Returns:",
    "Return:",
    "Yields:",
    "Yield:",
    "Raises:",
    "Examples:",
    "Example:",
    "Note:",
    "Notes:",
    "References:",
    "See Also:",
}

# `name (...)` or `name:` at the start of a (stripped) line.
_ARG_HEADER_RE = re.compile(r"^([A-Za-z_]\w*)\s*[(:]")

# Pairs of (class_name, method_name) that the checker skips entirely (no
# missing-arg, stale-arg or Returns errors are reported for them).
# Use sparingly — prefer fixing the docstring.
IGNORE: set[tuple[str, str]] = set()
def _base_class_names(class_def: ast.ClassDef) -> set[str]:
    """Collect the names of ``class_def``'s bases as written in the source.

    A plain name (``Foo``) contributes ``"Foo"``; a dotted reference
    (``pkg.Foo``) contributes only its final attribute, ``"Foo"``. Any other
    base expression (subscript, call, ...) is ignored, so this is best-effort.
    """
    plain = {base.id for base in class_def.bases if isinstance(base, ast.Name)}
    dotted = {base.attr for base in class_def.bases if isinstance(base, ast.Attribute)}
    return plain | dotted
def _find_method(class_def: ast.ClassDef, method_name: str) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
    """Return the first (sync or async) function named ``method_name`` defined
    directly in the body of ``class_def``, or ``None`` if there is none.

    Only top-level statements of the class body are inspected; nested scopes
    and inherited methods are not searched.
    """
    function_nodes = (ast.FunctionDef, ast.AsyncFunctionDef)
    matches = (
        stmt for stmt in class_def.body if isinstance(stmt, function_nodes) and stmt.name == method_name
    )
    return next(matches, None)
def _docstring_node(func: ast.FunctionDef | ast.AsyncFunctionDef) -> ast.Expr | None:
    """Return the expression statement that holds ``func``'s docstring.

    The docstring is the first body statement when that statement is a bare
    string constant; in every other case ``None`` is returned.
    """
    if not func.body:
        return None
    first_stmt = func.body[0]
    if not isinstance(first_stmt, ast.Expr):
        return None
    literal = first_stmt.value
    if isinstance(literal, ast.Constant) and isinstance(literal.value, str):
        return first_stmt
    return None
def _signature_arg_names(func: ast.FunctionDef | ast.AsyncFunctionDef) -> list[str]:
    """List ``func``'s parameter names in declaration order.

    Positional-only, regular and keyword-only parameters are included;
    ``self`` / ``cls`` are dropped. The variadic ``*args`` / ``**kwargs``
    parameters are not part of the result.
    """
    spec = func.args
    declared = [*spec.posonlyargs, *spec.args, *spec.kwonlyargs]
    return [param.arg for param in declared if param.arg not in ("self", "cls")]
def _has_meaningful_return(func: ast.FunctionDef | ast.AsyncFunctionDef) -> bool:
    """Tell whether ``func`` is annotated as returning an actual value.

    Returns ``False`` when there is no return annotation, or when the
    annotation is ``None``, ``NoReturn`` or ``<anything>.NoReturn``; every
    other annotation yields ``True``.
    """
    annotation = func.returns
    if annotation is None:
        # Unannotated method.
        return False
    if isinstance(annotation, ast.Constant):
        # `-> None` is the only constant annotation that counts as "no value".
        return annotation.value is not None
    if isinstance(annotation, ast.Name):
        return annotation.id != "NoReturn"
    if isinstance(annotation, ast.Attribute):
        # e.g. `-> typing.NoReturn`
        return annotation.attr != "NoReturn"
    return True
def _has_returns_section(docstring: str | None) -> bool:
    """Tell whether ``docstring`` contains a line that is exactly a
    ``Returns:`` / ``Return:`` / ``Yields:`` / ``Yield:`` header
    (ignoring surrounding whitespace). Empty or missing docstrings give ``False``.
    """
    if not docstring:
        return False
    return_headers = ("Returns:", "Return:", "Yields:", "Yield:")
    return any(text.strip() in return_headers for text in docstring.splitlines())
def _extract_documented_args(docstring: str | None) -> set[str]:
    """Collect the argument names listed under the first Args/Arguments/Parameters header.

    Assumes the docstring has been cleaned (``inspect.cleandoc`` / ``ast.get_docstring``).

    The indentation of the first non-blank line after the header defines the
    entry level; only lines at exactly that indentation are considered entry
    headers (anything else is a description/continuation). Scanning stops at
    the first line that is a known section header indented no deeper than the
    Args header, or at the end of the docstring.
    """
    if not docstring:
        return set()
    doc_lines = docstring.splitlines()

    # Position of the first Args/Arguments/Parameters header, if any.
    header_pos = next(
        (pos for pos, text in enumerate(doc_lines) if text.strip() in {"Args:", "Arguments:", "Parameters:"}),
        None,
    )
    if header_pos is None:
        return set()
    header_line = doc_lines[header_pos]
    header_indent = len(header_line) - len(header_line.lstrip())

    names: set[str] = set()
    entry_indent: int | None = None
    for raw in doc_lines[header_pos + 1 :]:
        text = raw.strip()
        if not text:
            continue
        depth = len(raw) - len(raw.lstrip())
        # A new section at the same (or shallower) indent ends the args block.
        if text in SECTION_HEADERS and depth <= header_indent:
            break
        if entry_indent is None:
            # First non-empty line fixes the per-entry indentation.
            entry_indent = depth
        elif depth != entry_indent:
            # Deeper/shallower lines are descriptions, not entry headers.
            continue
        found = _ARG_HEADER_RE.match(text)
        if found is not None:
            names.add(found.group(1))
    return names
def check_file(path: Path, kind: str) -> list[str]:
    """Return a list of human-readable error strings for ``path``.

    Args:
        path: Python source file to inspect.
        kind: ``"model"`` checks ``forward`` on ``ModelMixin`` subclasses; any
            other value checks ``__call__`` on ``DiffusionPipeline`` subclasses.

    Returns:
        One message per problem found: signature arguments missing from the
        docstring, documented arguments absent from the signature, and a
        missing Returns section on a method with a meaningful return
        annotation. Files that cannot be decoded or parsed yield ``[]``.
    """
    method_name = "forward" if kind == "model" else "__call__"
    base_class = MODEL_BASE if kind == "model" else PIPELINE_BASE
    try:
        tree = ast.parse(path.read_text(encoding="utf-8"))
    except (SyntaxError, UnicodeDecodeError):
        return []
    errors: list[str] = []
    # Prefer repo-relative paths in messages, but a file passed via --paths may
    # live outside REPO_ROOT, in which case ``relative_to`` raises ValueError.
    try:
        rel = path.relative_to(REPO_ROOT)
    except ValueError:
        rel = path
    for node in ast.walk(tree):
        if not isinstance(node, ast.ClassDef):
            continue
        if base_class not in _base_class_names(node):
            continue
        # Explicitly ignored (class, method) pairs are skipped entirely.
        if (node.name, method_name) in IGNORE:
            continue
        method = _find_method(node, method_name)
        if method is None:
            continue
        sig_args = _signature_arg_names(method)
        sig_set = set(sig_args)
        docstring_text = ast.get_docstring(method)
        documented = _extract_documented_args(docstring_text)
        # Keep signature order for missing args; sort stale ones for stable output.
        missing = [a for a in sig_args if a not in documented]
        stale = sorted(documented - sig_set)
        if missing:
            errors.append(
                f"{rel}:{method.lineno}: {node.name}.{method_name} is missing "
                f"docstring entries for: {', '.join(missing)}"
            )
        if stale:
            errors.append(
                f"{rel}:{method.lineno}: {node.name}.{method_name} documents "
                f"argument(s) not in the signature: {', '.join(stale)}"
            )
        if _has_meaningful_return(method) and not _has_returns_section(docstring_text):
            return_repr = ast.unparse(method.returns)
            ds = _docstring_node(method)
            if ds is None:
                where = " (method has no docstring)"
            else:
                where = f' (add it just above the closing """ on line {ds.end_lineno})'
            errors.append(
                f"{rel}:{method.lineno}: {node.name}.{method_name} returns "
                f"`{return_repr}` but the docstring has no Returns: section{where}"
            )
    return errors
def fix_file(path: Path, kind: str) -> list[str]:
    """Remove stale arg entries (documented but not in signature) in-place.

    Arguments that exist in the signature but are missing from the docstring
    are NOT added (no placeholders). Files that cannot be decoded or parsed
    are left untouched.

    Args:
        path: Python source file to rewrite.
        kind: ``"model"`` targets ``forward`` on ``ModelMixin`` subclasses; any
            other value targets ``__call__`` on ``DiffusionPipeline`` subclasses.

    Returns:
        A list of ``"ClassName.method: removed name1, name2"`` strings naming
        the entries whose lines were actually deleted (empty if nothing changed).
    """
    method_name = "forward" if kind == "model" else "__call__"
    base_class = MODEL_BASE if kind == "model" else PIPELINE_BASE
    # Read inside the ``try`` so an undecodable file is skipped, exactly as
    # check_file does, instead of aborting the whole run.
    try:
        source = path.read_text(encoding="utf-8")
        tree = ast.parse(source)
    except (SyntaxError, UnicodeDecodeError):
        return []
    # Split on "\n" only. ``str.splitlines`` also breaks on form feeds and
    # other separators that do not end a line for the parser, which would shift
    # these indices away from the ``ast`` line numbers used below.
    lines = source.split("\n")
    # (start_idx, end_idx_exclusive) ranges of lines to drop.
    deletions: list[tuple[int, int]] = []
    summaries: list[str] = []
    for node in ast.walk(tree):
        if not isinstance(node, ast.ClassDef):
            continue
        if base_class not in _base_class_names(node):
            continue
        method = _find_method(node, method_name)
        if method is None:
            continue
        # Method must start with a string docstring expression.
        docstring_expr = _docstring_node(method)
        if docstring_expr is None:
            continue
        sig_set = set(_signature_arg_names(method))
        stale = _extract_documented_args(ast.get_docstring(method)) - sig_set
        if not stale:
            continue
        ranges, removed = _stale_entry_ranges(
            lines,
            docstring_expr.lineno - 1,  # 0-indexed
            docstring_expr.end_lineno - 1,  # 0-indexed, inclusive
            stale,
        )
        if not ranges:
            continue
        deletions.extend(ranges)
        summaries.append(f"{node.name}.{method_name}: removed {', '.join(sorted(set(removed)))}")
    if not deletions:
        return []
    new_lines = list(lines)
    # Delete from the bottom up so earlier indices stay valid.
    for start, end in sorted(deletions, reverse=True):
        del new_lines[start:end]
    path.write_text("\n".join(new_lines), encoding="utf-8")
    return summaries


# `*name (...)`, `*name:`, `**name (...)` or `**name:` at the start of a
# (stripped) line: a variadic entry, which is never reported as stale but must
# still terminate the entry that precedes it.
_VARIADIC_HEADER_RE = re.compile(r"^\*{1,2}[A-Za-z_]\w*\s*[(:]")


def _stale_entry_ranges(
    lines: list[str], doc_start: int, doc_end: int, stale: set[str]
) -> tuple[list[tuple[int, int]], list[str]]:
    """Find the raw-source line ranges occupied by stale entries of one docstring.

    Args:
        lines: All source lines of the file.
        doc_start: 0-indexed first line of the docstring expression.
        doc_end: 0-indexed last line of the docstring expression (inclusive).
        stale: Argument names to remove.

    Returns:
        ``(ranges, removed)`` where ``ranges`` holds non-empty
        ``(start, end_exclusive)`` line ranges and ``removed`` the matching names.
    """
    # Locate the Args/Arguments/Parameters header in raw source.
    args_idx = next(
        (i for i in range(doc_start, doc_end + 1) if lines[i].strip() in {"Args:", "Arguments:", "Parameters:"}),
        None,
    )
    if args_idx is None:
        return [], []
    header_indent = len(lines[args_idx]) - len(lines[args_idx].lstrip())

    # First non-empty line after the header sets the per-entry indent.
    entry_indent: int | None = None
    for i in range(args_idx + 1, doc_end + 1):
        if lines[i].strip():
            entry_indent = len(lines[i]) - len(lines[i].lstrip())
            break
    if entry_indent is None or entry_indent <= header_indent:
        return [], []

    ranges: list[tuple[int, int]] = []
    removed: list[str] = []

    def close_entry(name: str | None, start: int, end: int) -> None:
        # Record only real deletions, so the summary never names an entry
        # whose lines could not be removed.
        if name in stale and end > start:
            ranges.append((start, end))
            removed.append(name)

    # Walk entries; each entry spans from its header line up to (but not
    # including) the next entry header / section header / end of docstring.
    current_name: str | None = None
    current_start = -1
    end_of_args: int | None = None
    for i in range(args_idx + 1, doc_end + 1):
        line = lines[i]
        stripped = line.strip()
        if not stripped:
            continue
        indent = len(line) - len(line.lstrip())
        if indent <= header_indent and stripped in SECTION_HEADERS:
            end_of_args = i
            break
        if indent != entry_indent:
            continue
        m = _ARG_HEADER_RE.match(stripped)
        if m is None and _VARIADIC_HEADER_RE.match(stripped) is None:
            # Not an entry header: stays part of the current entry.
            continue
        close_entry(current_name, current_start, i)
        # A variadic entry gets no name, so it can never be treated as stale.
        current_name = m.group(1) if m is not None else None
        current_start = i
    if current_name in stale:
        end = end_of_args if end_of_args is not None else doc_end
        # Trailing blank lines belong to inter-section spacing (or the
        # blank line before the closing """), not to this entry.
        while end > current_start + 1 and not lines[end - 1].strip():
            end -= 1
        close_entry(current_name, current_start, end)
    return ranges, removed
def _kind_for_path(path: Path) -> str | None:
    """Infer the check kind from the directory names in ``path``.

    Returns ``"pipeline"`` if any component of the resolved path is
    ``pipelines``, otherwise ``"model"`` if any component is ``models``,
    otherwise ``None``. ``pipelines`` takes precedence when both occur.
    """
    components = set(path.resolve().parts)
    for directory, kind in (("pipelines", "pipeline"), ("models", "model")):
        if directory in components:
            return kind
    return None
def main() -> int:
    """Command-line entry point: optionally fix, then check, the selected files.

    With ``--paths`` only the given files are processed (files whose kind
    cannot be inferred are skipped with a message on stderr); otherwise every
    ``*.py`` file under the models and pipelines directories is processed,
    optionally truncated by ``--limit``. With ``--fix`` stale entries are
    removed first, and the check then runs on the updated files.

    Returns:
        ``1`` if any docstring/signature mismatch is reported, ``0`` otherwise.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--paths",
        nargs="+",
        help="Specific files to check (defaults to all of src/diffusers/{models,pipelines}).",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help=(
            "Debug helper: when --paths is not given, only check the first N files "
            "(in sorted order) from each of models/ and pipelines/."
        ),
    )
    parser.add_argument(
        "--fix",
        action="store_true",
        help=(
            "Remove stale (documented-but-not-in-signature) argument entries from "
            "docstrings in-place. Missing-in-docstring entries are NOT auto-added "
            "(no placeholders) and will still be reported."
        ),
    )
    args = parser.parse_args()

    targets: list[tuple[Path, str]] = []
    if args.paths:
        for raw in args.paths:
            p = Path(raw).resolve()
            kind = _kind_for_path(p)
            if kind is None:
                print(f"Skipping {raw}: not under models/ or pipelines/.", file=sys.stderr)
                continue
            targets.append((p, kind))
    else:
        model_files = sorted(MODELS_DIR.rglob("*.py"))
        pipeline_files = sorted(PIPELINES_DIR.rglob("*.py"))
        if args.limit is not None:
            if args.limit < 0:
                parser.error("--limit must be non-negative")
            model_files = model_files[: args.limit]
            pipeline_files = pipeline_files[: args.limit]
            print(
                f"--limit {args.limit}: checking {len(model_files)} model file(s) "
                f"and {len(pipeline_files)} pipeline file(s).",
                file=sys.stderr,
            )
        targets.extend((p, "model") for p in model_files)
        targets.extend((p, "pipeline") for p in pipeline_files)

    if args.fix:
        fix_summaries: list[str] = []
        for path, kind in targets:
            file_summaries = fix_file(path, kind)
            if not file_summaries:
                continue
            # Files given via --paths may live outside REPO_ROOT; show them
            # as-is rather than crashing in ``relative_to``.
            try:
                shown = path.relative_to(REPO_ROOT)
            except ValueError:
                shown = path
            fix_summaries.extend(f"{shown}: {summary}" for summary in file_summaries)
        if fix_summaries:
            print("Removed stale docstring entries:")
            print("\n".join(f" {s}" for s in fix_summaries))
        else:
            print("No stale docstring entries to remove.")

    # Always (re-)check, so problems --fix cannot repair are still reported.
    all_errors: list[str] = []
    for path, kind in targets:
        all_errors.extend(check_file(path, kind))
    if all_errors:
        print("\n".join(all_errors))
        print(
            f"\nFound {len(all_errors)} docstring/signature mismatch(es).",
            file=sys.stderr,
        )
        if not args.fix and any("documents argument(s) not in the signature" in e for e in all_errors):
            print(
                "Hint: run `python utils/check_forward_call_docstrings.py --fix` "
                "to remove the stale argument entries flagged above. "
                "(Missing-in-docstring entries must be added manually — the tool "
                "never inserts placeholders.)",
                file=sys.stderr,
            )
        return 1
    print("All forward/__call__ arguments are documented.")
    return 0


if __name__ == "__main__":
    sys.exit(main())