mirror of
https://github.com/huggingface/diffusers.git
synced 2026-05-28 00:39:35 +08:00
Install transformers from main for doc and staging (#13723)
* Use Mistral3Model/Ministral3ForCausalLM
* [docs] add magcache to caching api listing (#13714)
add magcache to caching api listing
* install transformers from main
* up
* up
* up
* up[
* shorten deprecation cycle for flax.
* Revert "shorten deprecation cycle for flax."
This reverts commit 692d98db7b.
---------
Co-authored-by: Akshan Krithick <akshankrithick305@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
1
.github/workflows/build_documentation.yml
vendored
1
.github/workflows/build_documentation.yml
vendored
@@ -25,6 +25,7 @@ jobs:
|
||||
notebook_folder: diffusers_doc
|
||||
languages: en ko zh ja pt
|
||||
custom_container: diffusers/diffusers-doc-builder
|
||||
pre_command: uv pip uninstall transformers huggingface_hub && UV_PRERELEASE=allow uv pip install -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
secrets:
|
||||
token: ${{ secrets.HUGGINGFACE_PUSH }}
|
||||
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
|
||||
|
||||
1
.github/workflows/build_pr_documentation.yml
vendored
1
.github/workflows/build_pr_documentation.yml
vendored
@@ -50,3 +50,4 @@ jobs:
|
||||
package: diffusers
|
||||
languages: en ko zh ja pt
|
||||
custom_container: diffusers/diffusers-doc-builder
|
||||
pre_command: uv pip uninstall transformers huggingface_hub && UV_PRERELEASE=allow uv pip install -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
|
||||
2
.github/workflows/pr_tests.yml
vendored
2
.github/workflows/pr_tests.yml
vendored
@@ -194,6 +194,8 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall transformers huggingface_hub && UV_PRERELEASE=allow uv pip install -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall tokenizers && uv pip install "tokenizers<=0.23.0"
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
|
||||
@@ -22,6 +22,7 @@ from .utils import (
|
||||
is_torchao_available,
|
||||
is_torchsde_available,
|
||||
is_transformers_available,
|
||||
is_transformers_flax_compatible,
|
||||
is_transformers_version,
|
||||
)
|
||||
|
||||
@@ -861,7 +862,6 @@ else:
|
||||
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
|
||||
_import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
|
||||
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
|
||||
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
|
||||
_import_structure["schedulers"].extend(
|
||||
[
|
||||
"FlaxDDIMScheduler",
|
||||
@@ -878,7 +878,7 @@ else:
|
||||
|
||||
|
||||
try:
|
||||
if not (is_flax_available() and is_transformers_available()):
|
||||
if not (is_flax_available() and is_transformers_available() and is_transformers_flax_compatible()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils import dummy_flax_and_transformers_objects # noqa F403
|
||||
@@ -891,6 +891,7 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["pipelines"].extend(
|
||||
[
|
||||
"FlaxDiffusionPipeline",
|
||||
"FlaxStableDiffusionControlNetPipeline",
|
||||
"FlaxStableDiffusionImg2ImgPipeline",
|
||||
"FlaxStableDiffusionInpaintPipeline",
|
||||
@@ -1620,7 +1621,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .models.modeling_flax_utils import FlaxModelMixin
|
||||
from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel
|
||||
from .models.vae_flax import FlaxAutoencoderKL
|
||||
from .pipelines import FlaxDiffusionPipeline
|
||||
from .schedulers import (
|
||||
FlaxDDIMScheduler,
|
||||
FlaxDDPMScheduler,
|
||||
@@ -1634,12 +1634,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_flax_available() and is_transformers_available()):
|
||||
if not (is_flax_available() and is_transformers_available() and is_transformers_flax_compatible()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipelines import (
|
||||
FlaxDiffusionPipeline,
|
||||
FlaxStableDiffusionControlNetPipeline,
|
||||
FlaxStableDiffusionImg2ImgPipeline,
|
||||
FlaxStableDiffusionInpaintPipeline,
|
||||
|
||||
@@ -15,16 +15,23 @@
|
||||
import json
|
||||
|
||||
import torch
|
||||
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers import AutoTokenizer, Mistral3Model
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...utils import logging
|
||||
from ...utils.import_utils import is_transformers_version
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import ErnieImageModularPipeline
|
||||
|
||||
|
||||
if is_transformers_version("<", "5.0.0"):
|
||||
raise ImportError("`ErnieImageModularPipeline` requires `transformers>=5.0.0` for `Ministral3ForCausalLM`.")
|
||||
|
||||
from transformers import Ministral3ForCausalLM # noqa: E402
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@@ -38,7 +45,7 @@ class ErnieImagePromptEnhancerStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("pe", AutoModelForCausalLM),
|
||||
ComponentSpec("pe", Ministral3ForCausalLM),
|
||||
ComponentSpec("pe_tokenizer", AutoTokenizer),
|
||||
]
|
||||
|
||||
@@ -83,7 +90,7 @@ class ErnieImagePromptEnhancerStep(ModularPipelineBlocks):
|
||||
|
||||
@staticmethod
|
||||
def _enhance_prompt(
|
||||
pe: AutoModelForCausalLM,
|
||||
pe: Ministral3ForCausalLM,
|
||||
pe_tokenizer: AutoTokenizer,
|
||||
prompt: str,
|
||||
device: torch.device,
|
||||
@@ -160,7 +167,7 @@ class ErnieImageTextEncoderStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("text_encoder", AutoModel),
|
||||
ComponentSpec("text_encoder", Mistral3Model),
|
||||
ComponentSpec("tokenizer", AutoTokenizer),
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
@@ -200,7 +207,7 @@ class ErnieImageTextEncoderStep(ModularPipelineBlocks):
|
||||
|
||||
@staticmethod
|
||||
def _encode(
|
||||
text_encoder: AutoModel,
|
||||
text_encoder: Mistral3Model,
|
||||
tokenizer: AutoTokenizer,
|
||||
prompt: list[str],
|
||||
device: torch.device,
|
||||
|
||||
@@ -5,7 +5,6 @@ from ..utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_flax_available,
|
||||
is_librosa_available,
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
@@ -14,6 +13,7 @@ from ..utils import (
|
||||
is_torch_available,
|
||||
is_torch_npu_available,
|
||||
is_transformers_available,
|
||||
is_transformers_flax_compatible,
|
||||
is_transformers_version,
|
||||
)
|
||||
|
||||
@@ -504,7 +504,7 @@ else:
|
||||
_import_structure["consisid"] = ["ConsisIDPipeline"]
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
if not is_transformers_flax_compatible():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import dummy_flax_objects # noqa F403
|
||||
@@ -513,7 +513,7 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["pipeline_flax_utils"] = ["FlaxDiffusionPipeline"]
|
||||
try:
|
||||
if not (is_flax_available() and is_transformers_available()):
|
||||
if not is_transformers_flax_compatible():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import dummy_flax_and_transformers_objects # noqa F403
|
||||
@@ -930,7 +930,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .consisid import ConsisIDPipeline
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
if not is_transformers_flax_compatible():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_flax_objects import * # noqa F403
|
||||
@@ -938,7 +938,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_flax_utils import FlaxDiffusionPipeline
|
||||
|
||||
try:
|
||||
if not (is_flax_available() and is_transformers_available()):
|
||||
if not is_transformers_flax_compatible():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_flax_and_transformers_objects import *
|
||||
|
||||
@@ -5,9 +5,9 @@ from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_flax_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
is_transformers_flax_compatible,
|
||||
)
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ else:
|
||||
_import_structure["pipeline_controlnet_union_sd_xl"] = ["StableDiffusionXLControlNetUnionPipeline"]
|
||||
_import_structure["pipeline_controlnet_union_sd_xl_img2img"] = ["StableDiffusionXLControlNetUnionImg2ImgPipeline"]
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
if not is_transformers_flax_compatible():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_flax_and_transformers_objects # noqa F403
|
||||
@@ -65,7 +65,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_controlnet_union_sd_xl_img2img import StableDiffusionXLControlNetUnionImg2ImgPipeline
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
if not is_transformers_flax_compatible():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
|
||||
|
||||
@@ -5,9 +5,9 @@ from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_flax_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
is_transformers_flax_compatible,
|
||||
)
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_stable_diffusion_3_controlnet_inpainting import StableDiffusion3ControlNetInpaintingPipeline
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
if not is_transformers_flax_compatible():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
|
||||
|
||||
@@ -5,9 +5,9 @@ from ....utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_flax_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
is_transformers_flax_compatible,
|
||||
)
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ else:
|
||||
_import_structure["pipeline_controlnet_xs"] = ["StableDiffusionControlNetXSPipeline"]
|
||||
_import_structure["pipeline_controlnet_xs_sd_xl"] = ["StableDiffusionXLControlNetXSPipeline"]
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
if not is_transformers_flax_compatible():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ....utils import dummy_flax_and_transformers_objects # noqa F403
|
||||
@@ -47,7 +47,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_controlnet_xs_sd_xl import StableDiffusionXLControlNetXSPipeline
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
if not is_transformers_flax_compatible():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ....utils.dummy_flax_and_transformers_objects import * # noqa F403
|
||||
|
||||
@@ -20,7 +20,7 @@ import json
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers import AutoTokenizer, Mistral3Model
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import ErnieImageLoraLoaderMixin
|
||||
@@ -28,10 +28,17 @@ from ...models import AutoencoderKLFlux2
|
||||
from ...models.transformers import ErnieImageTransformer2DModel
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils.import_utils import is_transformers_version
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from .pipeline_output import ErnieImagePipelineOutput
|
||||
|
||||
|
||||
if is_transformers_version("<", "5.0.0"):
|
||||
raise ImportError("`ErnieImagePipeline` requires `transformers>=5.0.0` for `Ministral3ForCausalLM`.")
|
||||
|
||||
from transformers import Ministral3ForCausalLM # noqa: E402
|
||||
|
||||
|
||||
class ErnieImagePipeline(DiffusionPipeline, ErnieImageLoraLoaderMixin):
|
||||
"""
|
||||
Pipeline for text-to-image generation using ErnieImageTransformer2DModel.
|
||||
@@ -52,10 +59,10 @@ class ErnieImagePipeline(DiffusionPipeline, ErnieImageLoraLoaderMixin):
|
||||
self,
|
||||
transformer: ErnieImageTransformer2DModel,
|
||||
vae: AutoencoderKLFlux2,
|
||||
text_encoder: AutoModel,
|
||||
text_encoder: Mistral3Model,
|
||||
tokenizer: AutoTokenizer,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
pe: Optional[AutoModelForCausalLM] = None,
|
||||
pe: Optional[Ministral3ForCausalLM] = None,
|
||||
pe_tokenizer: Optional[AutoTokenizer] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -36,12 +36,12 @@ from ..utils import (
|
||||
BaseOutput,
|
||||
PushToHubMixin,
|
||||
http_user_agent,
|
||||
is_transformers_available,
|
||||
is_transformers_flax_compatible,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
if is_transformers_flax_compatible():
|
||||
from transformers import FlaxPreTrainedModel
|
||||
|
||||
INDEX_FILE = "diffusion_flax_model.bin"
|
||||
@@ -501,7 +501,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
dtype=dtype,
|
||||
)
|
||||
params[name] = loaded_params
|
||||
elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
|
||||
elif is_transformers_flax_compatible() and issubclass(class_obj, FlaxPreTrainedModel):
|
||||
if from_pt:
|
||||
# TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
|
||||
loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
|
||||
|
||||
@@ -5,10 +5,10 @@ from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_flax_available,
|
||||
is_onnx_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
is_transformers_flax_compatible,
|
||||
is_transformers_version,
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@ _dummy_objects = {}
|
||||
_additional_imports = {}
|
||||
_import_structure = {"pipeline_output": ["StableDiffusionPipelineOutput"]}
|
||||
|
||||
if is_transformers_available() and is_flax_available():
|
||||
if is_transformers_flax_compatible():
|
||||
_import_structure["pipeline_output"].extend(["FlaxStableDiffusionPipelineOutput"])
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
@@ -82,7 +82,7 @@ else:
|
||||
_import_structure["pipeline_onnx_stable_diffusion_inpaint_legacy"] = ["OnnxStableDiffusionInpaintPipelineLegacy"]
|
||||
_import_structure["pipeline_onnx_stable_diffusion_upscale"] = ["OnnxStableDiffusionUpscalePipeline"]
|
||||
|
||||
if is_transformers_available() and is_flax_available():
|
||||
if is_transformers_flax_compatible():
|
||||
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
|
||||
|
||||
_additional_imports.update({"PNDMSchedulerState": PNDMSchedulerState})
|
||||
@@ -162,7 +162,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
if not is_transformers_flax_compatible():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_flax_objects import *
|
||||
|
||||
@@ -5,9 +5,9 @@ from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_flax_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
is_transformers_flax_compatible,
|
||||
)
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ _dummy_objects = {}
|
||||
_additional_imports = {}
|
||||
_import_structure = {"pipeline_output": ["StableDiffusionXLPipelineOutput"]}
|
||||
|
||||
if is_transformers_available() and is_flax_available():
|
||||
if is_transformers_flax_compatible():
|
||||
_import_structure["pipeline_output"].extend(["FlaxStableDiffusionXLPipelineOutput"])
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
@@ -30,7 +30,7 @@ else:
|
||||
_import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"]
|
||||
|
||||
if is_transformers_available() and is_flax_available():
|
||||
if is_transformers_flax_compatible():
|
||||
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
|
||||
|
||||
_additional_imports.update({"PNDMSchedulerState": PNDMSchedulerState})
|
||||
@@ -50,7 +50,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
if not is_transformers_flax_compatible():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_flax_objects import *
|
||||
|
||||
@@ -122,6 +122,7 @@ from .import_utils import (
|
||||
is_torchsde_available,
|
||||
is_torchvision_available,
|
||||
is_transformers_available,
|
||||
is_transformers_flax_compatible,
|
||||
is_transformers_version,
|
||||
is_unidecode_available,
|
||||
is_wandb_available,
|
||||
|
||||
@@ -2,6 +2,21 @@
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class FlaxDiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["flax", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["flax", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["flax", "transformers"])
|
||||
|
||||
|
||||
class FlaxStableDiffusionControlNetPipeline(metaclass=DummyObject):
|
||||
_backends = ["flax", "transformers"]
|
||||
|
||||
|
||||
@@ -62,21 +62,6 @@ class FlaxAutoencoderKL(metaclass=DummyObject):
|
||||
requires_backends(cls, ["flax"])
|
||||
|
||||
|
||||
class FlaxDiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["flax"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["flax"])
|
||||
|
||||
|
||||
class FlaxDDIMScheduler(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
|
||||
92
src/diffusers/utils/dummy_transformers_flax_objects.py
Normal file
92
src/diffusers/utils/dummy_transformers_flax_objects.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class FlaxDiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["transformers_flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["transformers_flax"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["transformers_flax"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["transformers_flax"])
|
||||
|
||||
|
||||
class FlaxStableDiffusionControlNetPipeline(metaclass=DummyObject):
|
||||
_backends = ["transformers_flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["transformers_flax"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["transformers_flax"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["transformers_flax"])
|
||||
|
||||
|
||||
class FlaxStableDiffusionImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["transformers_flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["transformers_flax"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["transformers_flax"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["transformers_flax"])
|
||||
|
||||
|
||||
class FlaxStableDiffusionInpaintPipeline(metaclass=DummyObject):
|
||||
_backends = ["transformers_flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["transformers_flax"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["transformers_flax"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["transformers_flax"])
|
||||
|
||||
|
||||
class FlaxStableDiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["transformers_flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["transformers_flax"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["transformers_flax"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["transformers_flax"])
|
||||
|
||||
|
||||
class FlaxStableDiffusionXLPipeline(metaclass=DummyObject):
|
||||
_backends = ["transformers_flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["transformers_flax"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["transformers_flax"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["transformers_flax"])
|
||||
@@ -258,6 +258,22 @@ def is_transformers_available():
|
||||
return _transformers_available
|
||||
|
||||
|
||||
def is_transformers_flax_compatible():
|
||||
# Flax classes (e.g. FlaxCLIPTextModel, FlaxPreTrainedModel) were removed from
|
||||
# transformers main on the path to its v5 release. Gate Flax pipeline registration
|
||||
# on transformers still shipping them so `import diffusers` doesn't crash.
|
||||
# Name avoids the `is_*_available()` pattern so utils/check_dummies.py keeps
|
||||
# generating the `flax_and_transformers` backend group when this is combined with
|
||||
# the legacy is_flax_available()/is_transformers_available() pair.
|
||||
if not (_transformers_available and _flax_available):
|
||||
return False
|
||||
try:
|
||||
import transformers
|
||||
except ImportError:
|
||||
return False
|
||||
return hasattr(transformers, "FlaxPreTrainedModel")
|
||||
|
||||
|
||||
def is_inflect_available():
|
||||
return _inflect_available
|
||||
|
||||
|
||||
Reference in New Issue
Block a user