mirror of
https://github.com/huggingface/diffusers.git
synced 2026-06-02 00:01:34 +08:00
* feat: allow docstring checker to fix unused args in docstrings. * feat: check for returns, too. * [docs] add missing Returns sections to forward/__call__ docstrings (#13830) Adds a Returns: section to the 43 model forward() and pipeline __call__() methods flagged by utils/check_forward_call_docstrings.py, which requires a Returns: section whenever the method has a non-None return annotation. Descriptions reflect the actual return statements (Output dataclass when return_dict=True, plain tuple otherwise; bare tensors / lists where applicable) and reuse each file's existing doc-builder link form. Also reformats a malformed single-line Returns: in pipeline_aura_flow.py that the check could not detect. Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com> * Apply style fixes --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
471 lines
16 KiB
Python
471 lines
16 KiB
Python
# coding=utf-8
|
|
# Copyright 2026 The HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
Check that arguments of ``forward()`` (for models) and ``__call__()`` (for
|
|
pipelines) match the method's docstring exactly:
|
|
|
|
* every signature argument has an entry in the ``Args:`` /
|
|
``Arguments:`` / ``Parameters:`` section,
|
|
* every documented argument still exists in the signature
|
|
(stale entries from removed/renamed args are flagged), and
|
|
* when the method has a non-``None`` return annotation, the docstring has
|
|
a ``Returns:`` / ``Return:`` / ``Yields:`` section.
|
|
|
|
A "main" class is detected via its base classes — models inherit from
|
|
``ModelMixin`` and pipelines inherit from ``DiffusionPipeline``. Only methods
|
|
defined directly on the class are checked; inherited methods are checked when
|
|
the parent class is visited.
|
|
|
|
Run from the repository root:
|
|
|
|
python utils/check_forward_call_docstrings.py
|
|
|
|
Optionally restrict to specific files:
|
|
|
|
python utils/check_forward_call_docstrings.py --paths src/diffusers/models/transformers/transformer_flux.py
|
|
|
|
Auto-fix stale (documented-but-removed) entries — missing entries are never
|
|
auto-added (no placeholders), only stale ones are removed:
|
|
|
|
python utils/check_forward_call_docstrings.py --fix
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import ast
|
|
import re
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
|
|
# Repository root: the parent of the directory that contains this script.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Source trees scanned when no explicit --paths are given.
MODELS_DIR = REPO_ROOT / "src" / "diffusers" / "models"
PIPELINES_DIR = REPO_ROOT / "src" / "diffusers" / "pipelines"

# Base-class names that mark a class as a "main" model / pipeline class.
MODEL_BASE = "ModelMixin"
PIPELINE_BASE = "DiffusionPipeline"

# Stripped docstring lines that open a new section (and therefore terminate
# the Args/Arguments/Parameters block when seen at header indent or shallower).
SECTION_HEADERS = {
    "Args:", "Arguments:", "Parameters:",
    "Returns:", "Return:", "Yields:", "Raises:",
    "Examples:", "Example:",
    "Note:", "Notes:",
    "References:", "See Also:",
}

# `name (...)` or `name:` at the start of a (stripped) line.
_ARG_HEADER_RE = re.compile(r"^([A-Za-z_]\w*)\s*[(:]")

# Pairs of (class_name, method_name) whose missing-arg errors should be
# suppressed. Use sparingly — prefer fixing the docstring.
IGNORE: set[tuple[str, str]] = set()
|
|
|
|
|
|
def _base_class_names(class_def: ast.ClassDef) -> set[str]:
    """Collect the simple names of the direct bases of ``class_def`` (best-effort).

    A bare name such as ``Foo`` contributes ``"Foo"``; a dotted reference such
    as ``pkg.Foo`` contributes its last component ``"Foo"``. Any other kind of
    base expression (subscripts, calls, ...) is ignored.
    """
    bases = class_def.bases
    plain = {b.id for b in bases if isinstance(b, ast.Name)}
    dotted = {b.attr for b in bases if isinstance(b, ast.Attribute)}
    return plain | dotted
|
|
|
|
|
|
def _find_method(class_def: ast.ClassDef, method_name: str) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
    """Return the first sync or async function named ``method_name`` in the class body.

    Only statements directly in ``class_def.body`` are inspected; nested
    definitions are not searched. Returns ``None`` when there is no match.
    """
    matches = (
        stmt
        for stmt in class_def.body
        if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)) and stmt.name == method_name
    )
    return next(matches, None)
|
|
|
|
|
|
def _docstring_node(func: ast.FunctionDef | ast.AsyncFunctionDef) -> ast.Expr | None:
    """Return the expression statement holding ``func``'s docstring, if any.

    The docstring is the first body statement when that statement is a bare
    string constant; otherwise ``None`` is returned.
    """
    if not func.body:
        return None
    first = func.body[0]
    if not isinstance(first, ast.Expr):
        return None
    literal = first.value
    if isinstance(literal, ast.Constant) and isinstance(literal.value, str):
        return first
    return None
|
|
|
|
|
|
def _signature_arg_names(func: ast.FunctionDef | ast.AsyncFunctionDef) -> list[str]:
    """List the named parameters of ``func`` in declaration order.

    Positional-only, regular and keyword-only parameters are included;
    ``*args`` / ``**kwargs`` catch-alls are not. Parameters called ``self``
    or ``cls`` are dropped.
    """
    spec = func.args
    declared = [*spec.posonlyargs, *spec.args, *spec.kwonlyargs]
    return [param.arg for param in declared if param.arg not in ("self", "cls")]
|
|
|
|
|
|
def _has_meaningful_return(func: ast.FunctionDef | ast.AsyncFunctionDef) -> bool:
    """Tell whether ``func`` is annotated to return something.

    Returns ``False`` when there is no return annotation, or when the
    annotation is ``None``, ``NoReturn`` or ``<anything>.NoReturn``; every
    other annotation yields ``True``.
    """
    annotation = func.returns
    if annotation is None:  # unannotated
        return False
    if isinstance(annotation, ast.Constant):  # `-> None` is the only rejected constant
        return annotation.value is not None
    if isinstance(annotation, ast.Name):  # `-> NoReturn`
        return annotation.id != "NoReturn"
    if isinstance(annotation, ast.Attribute):  # `-> typing.NoReturn`
        return annotation.attr != "NoReturn"
    return True
|
|
|
|
|
|
def _has_returns_section(docstring: str | None) -> bool:
    """Tell whether ``docstring`` contains a returns/yields section header.

    A header is a line that, once stripped, is exactly ``Returns:``,
    ``Return:``, ``Yields:`` or ``Yield:``. A missing or empty docstring
    never has one.
    """
    if not docstring:
        return False
    headers = {"Returns:", "Return:", "Yields:", "Yield:"}
    return any(raw.strip() in headers for raw in docstring.splitlines())
|
|
|
|
|
|
def _extract_documented_args(docstring: str | None) -> set[str]:
    """Extract argument names listed in an Args/Arguments/Parameters section.

    Assumes the docstring has been cleaned (``inspect.cleandoc`` /
    ``ast.get_docstring``). Only the first such section is read. It ends at
    the next section header found at the header's indent (or shallower), or
    at the end of the docstring.

    Entries must be indented deeper than the section header. If the first
    non-empty line after the header is not, the section is treated as empty,
    so that trailing prose such as ``Note: ...`` is not mistaken for an
    argument. This matches the rule ``fix_file`` applies to the raw source.

    Args:
        docstring: Cleaned docstring text, or ``None`` / ``""`` when absent.

    Returns:
        The set of documented argument names (empty when there is no
        docstring, no args section, or no entries).
    """
    if not docstring:
        return set()

    lines = docstring.splitlines()

    # Locate the Args/Arguments/Parameters header.
    start = None
    header_indent = 0
    for i, line in enumerate(lines):
        if line.strip() in {"Args:", "Arguments:", "Parameters:"}:
            start = i + 1
            header_indent = len(line) - len(line.lstrip())
            break
    if start is None:
        return set()

    # First non-empty line after the header sets the per-entry indent level.
    entry_indent: int | None = None
    documented: set[str] = set()

    for line in lines[start:]:
        stripped = line.strip()
        if not stripped:
            continue
        indent = len(line) - len(line.lstrip())

        if entry_indent is None:
            # Nothing is nested under the header: the section is empty and
            # whatever follows belongs to the rest of the docstring.
            if indent <= header_indent:
                break
            entry_indent = indent

        # A new section at the same (or shallower) indent ends the args block.
        if indent <= header_indent and stripped in SECTION_HEADERS:
            break

        # Only lines at the entry indent are candidate arg headers; deeper
        # indents are descriptions/continuations.
        if indent != entry_indent:
            continue

        match = _ARG_HEADER_RE.match(stripped)
        if match:
            documented.add(match.group(1))

    return documented
|
|
|
|
|
|
def check_file(path: Path, kind: str) -> list[str]:
    """Return a list of human-readable error strings for ``path``.

    Every class in the file (at any nesting level) that lists the relevant
    base class and defines the relevant method directly is checked for:

    * signature arguments with no docstring entry,
    * docstring entries with no matching signature argument, and
    * a missing Returns section when the return annotation is meaningful.

    Classes whose ``(class_name, method_name)`` pair is in ``IGNORE`` are
    skipped entirely.

    Args:
        path: Python source file to inspect.
        kind: ``"model"`` checks ``forward`` on ``ModelMixin`` subclasses;
            any other value checks ``__call__`` on ``DiffusionPipeline``
            subclasses.

    Returns:
        One message per problem, prefixed with ``<file>:<line>:``. The list is
        empty when the file is clean or cannot be decoded/parsed.
    """
    method_name = "forward" if kind == "model" else "__call__"
    base_class = MODEL_BASE if kind == "model" else PIPELINE_BASE

    try:
        tree = ast.parse(path.read_text(encoding="utf-8"))
    except (SyntaxError, ValueError):
        # Best-effort: skip files that are not valid UTF-8 Python.
        # ValueError covers UnicodeDecodeError and the null-byte error that
        # older interpreters raise from the compiler.
        return []

    errors: list[str] = []
    try:
        rel = path.relative_to(REPO_ROOT)
    except ValueError:
        # File passed via --paths lives outside this checkout; report it as given.
        rel = path

    for node in ast.walk(tree):
        if not isinstance(node, ast.ClassDef):
            continue
        if base_class not in _base_class_names(node):
            continue
        if (node.name, method_name) in IGNORE:
            continue
        method = _find_method(node, method_name)
        if method is None:
            continue

        sig_args = _signature_arg_names(method)
        sig_set = set(sig_args)
        docstring_text = ast.get_docstring(method)
        documented = _extract_documented_args(docstring_text)

        # `missing` keeps signature order; `stale` is sorted for stable output.
        missing = [a for a in sig_args if a not in documented]
        stale = sorted(documented - sig_set)
        if missing:
            errors.append(
                f"{rel}:{method.lineno}: {node.name}.{method_name} is missing "
                f"docstring entries for: {', '.join(missing)}"
            )
        if stale:
            errors.append(
                f"{rel}:{method.lineno}: {node.name}.{method_name} documents "
                f"argument(s) not in the signature: {', '.join(stale)}"
            )
        if _has_meaningful_return(method) and not _has_returns_section(docstring_text):
            return_repr = ast.unparse(method.returns)
            ds = _docstring_node(method)
            if ds is None:
                where = " (method has no docstring)"
            else:
                where = f' (add it just above the closing """ on line {ds.end_lineno})'
            errors.append(
                f"{rel}:{method.lineno}: {node.name}.{method_name} returns "
                f"`{return_repr}` but the docstring has no Returns: section{where}"
            )
    return errors
|
|
|
|
|
|
def fix_file(path: Path, kind: str) -> list[str]:
    """Remove stale arg entries (documented but not in signature) in-place.

    Missing-in-signature → docstring entries are NOT added (no placeholders).
    The file is rewritten only when at least one entry is removed. Files that
    cannot be decoded as UTF-8 or parsed are skipped, as in ``check_file``.

    Args:
        path: Python source file to rewrite.
        kind: ``"model"`` targets ``forward`` on ``ModelMixin`` subclasses;
            any other value targets ``__call__`` on ``DiffusionPipeline``
            subclasses.

    Returns:
        A list of ``"ClassName.method: removed name1, name2"`` strings
        describing what was removed (empty when nothing changed).
    """
    method_name = "forward" if kind == "model" else "__call__"
    base_class = MODEL_BASE if kind == "model" else PIPELINE_BASE

    try:
        # Reading must be inside the try: a decode failure is raised here,
        # not by ast.parse.
        source = path.read_text(encoding="utf-8")
        tree = ast.parse(source)
    except (SyntaxError, ValueError):
        # ValueError covers UnicodeDecodeError and the null-byte error that
        # older interpreters raise from the compiler.
        return []

    # Split on "\n" only so list indices line up with ast line numbers.
    # str.splitlines() also breaks on form feeds and Unicode separators,
    # which would shift every index after such a character.
    lines = source.split("\n")
    # (start_idx, end_idx_exclusive) ranges of lines to drop.
    deletions: list[tuple[int, int]] = []
    summaries: list[str] = []

    for node in ast.walk(tree):
        if not isinstance(node, ast.ClassDef):
            continue
        if base_class not in _base_class_names(node):
            continue
        method = _find_method(node, method_name)
        if method is None:
            continue
        # Method must start with a string docstring expression.
        docstring_expr = _docstring_node(method)
        if docstring_expr is None:
            continue

        sig_set = set(_signature_arg_names(method))
        documented = _extract_documented_args(ast.get_docstring(method))
        stale = documented - sig_set
        if not stale:
            continue

        doc_start = docstring_expr.lineno - 1  # 0-indexed
        doc_end = docstring_expr.end_lineno - 1  # 0-indexed, inclusive

        # Locate the Args/Arguments/Parameters header in raw source.
        args_idx: int | None = None
        header_indent = 0
        for i in range(doc_start, doc_end + 1):
            if lines[i].strip() in {"Args:", "Arguments:", "Parameters:"}:
                args_idx = i
                header_indent = len(lines[i]) - len(lines[i].lstrip())
                break
        if args_idx is None:
            continue

        # First non-empty line after the header sets the per-entry indent.
        entry_indent: int | None = None
        for i in range(args_idx + 1, doc_end + 1):
            if not lines[i].strip():
                continue
            entry_indent = len(lines[i]) - len(lines[i].lstrip())
            break
        if entry_indent is None or entry_indent <= header_indent:
            continue

        # Walk entries; each entry spans from its header line up to (but not
        # including) the next entry header / section header / end of docstring.
        current_name: str | None = None
        current_start: int = -1
        end_of_args: int | None = None

        for i in range(args_idx + 1, doc_end + 1):
            line = lines[i]
            stripped = line.strip()
            if not stripped:
                continue
            indent = len(line) - len(line.lstrip())

            if indent <= header_indent and stripped in SECTION_HEADERS:
                end_of_args = i
                break

            if indent == entry_indent:
                m = _ARG_HEADER_RE.match(stripped)
                if m:
                    # A new entry starts here, which closes the previous one.
                    if current_name in stale:
                        deletions.append((current_start, i))
                    current_name = m.group(1)
                    current_start = i

        # The last entry runs to the next section header, or to the line
        # holding the end of the docstring (which is kept).
        if current_name in stale:
            end = end_of_args if end_of_args is not None else doc_end
            # Trailing blank lines belong to inter-section spacing (or the
            # blank line before the closing """), not to this entry.
            while end > current_start + 1 and not lines[end - 1].strip():
                end -= 1
            deletions.append((current_start, end))

        summaries.append(f"{node.name}.{method_name}: removed {', '.join(sorted(stale))}")

    if not deletions:
        return []

    # Delete from the bottom up so earlier indices stay valid.
    deletions.sort()
    new_lines = list(lines)
    for start, end in reversed(deletions):
        del new_lines[start:end]
    path.write_text("\n".join(new_lines), encoding="utf-8")
    return summaries
|
|
|
|
|
|
def _kind_for_path(path: Path) -> str | None:
    """Classify ``path`` as ``"pipeline"`` or ``"model"`` from its directory names.

    The resolved path is split into components. A ``pipelines`` component
    wins over a ``models`` component; ``None`` is returned when neither is
    present.
    """
    components = path.resolve().parts
    for folder, kind in (("pipelines", "pipeline"), ("models", "model")):
        if folder in components:
            return kind
    return None
|
|
|
|
|
|
def main() -> int:
    """Command-line entry point.

    Builds the list of files to inspect (explicit ``--paths``, or every
    ``*.py`` under the models and pipelines directories, optionally truncated
    by ``--limit``), applies ``--fix`` if requested, then runs the checks.

    Returns:
        ``1`` when any docstring/signature mismatch remains, ``0`` otherwise.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--paths",
        nargs="+",
        help="Specific files to check (defaults to all of src/diffusers/{models,pipelines}).",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help=(
            "Debug helper: when --paths is not given, only check the first N files "
            "(in sorted order) from each of models/ and pipelines/."
        ),
    )
    parser.add_argument(
        "--fix",
        action="store_true",
        help=(
            "Remove stale (documented-but-not-in-signature) argument entries from "
            "docstrings in-place. Missing-in-docstring entries are NOT auto-added "
            "(no placeholders) and will still be reported."
        ),
    )
    args = parser.parse_args()

    targets: list[tuple[Path, str]] = []
    if args.paths:
        for raw in args.paths:
            p = Path(raw).resolve()
            kind = _kind_for_path(p)
            if kind is None:
                print(f"Skipping {raw}: not under models/ or pipelines/.", file=sys.stderr)
                continue
            targets.append((p, kind))
    else:
        model_files = sorted(MODELS_DIR.rglob("*.py"))
        pipeline_files = sorted(PIPELINES_DIR.rglob("*.py"))
        if args.limit is not None:
            if args.limit < 0:
                parser.error("--limit must be non-negative")
            model_files = model_files[: args.limit]
            pipeline_files = pipeline_files[: args.limit]
            print(
                f"--limit {args.limit}: checking {len(model_files)} model file(s) "
                f"and {len(pipeline_files)} pipeline file(s).",
                file=sys.stderr,
            )
        targets.extend((p, "model") for p in model_files)
        targets.extend((p, "pipeline") for p in pipeline_files)

    if args.fix:
        fix_summaries: list[str] = []
        for path, kind in targets:
            removed = fix_file(path, kind)
            if not removed:
                continue
            try:
                shown = path.relative_to(REPO_ROOT)
            except ValueError:
                # File passed via --paths lives outside this checkout; it has
                # already been rewritten, so report it rather than crash.
                shown = path
            fix_summaries.extend(f"{shown}: {summary}" for summary in removed)
        if fix_summaries:
            print("Removed stale docstring entries:")
            print("\n".join(f"  {s}" for s in fix_summaries))
        else:
            print("No stale docstring entries to remove.")

    # Always re-check, so that anything --fix could not repair is still reported.
    all_errors: list[str] = []
    for path, kind in targets:
        all_errors.extend(check_file(path, kind))

    if all_errors:
        print("\n".join(all_errors))
        print(
            f"\nFound {len(all_errors)} docstring/signature mismatch(es).",
            file=sys.stderr,
        )
        if not args.fix and any("documents argument(s) not in the signature" in e for e in all_errors):
            print(
                "Hint: run `python utils/check_forward_call_docstrings.py --fix` "
                "to remove the stale argument entries flagged above. "
                "(Missing-in-docstring entries must be added manually — the tool "
                "never inserts placeholders.)",
                file=sys.stderr,
            )
        return 1

    print("All forward/__call__ arguments are documented.")
    return 0


if __name__ == "__main__":
    sys.exit(main())
|