Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 107 additions & 29 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
is_transformers_available,
)


# Lazy Import based on
# https://github.com/huggingface/transformers/blob/main/src/transformers/__init__.py

Expand Down Expand Up @@ -60,7 +59,11 @@
}

try:
if not is_torch_available() and not is_accelerate_available() and not is_bitsandbytes_available():
if (
not is_torch_available()
and not is_accelerate_available()
and not is_bitsandbytes_available()
):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_bitsandbytes_objects
Expand All @@ -72,7 +75,11 @@
_import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig")

try:
if not is_torch_available() and not is_accelerate_available() and not is_gguf_available():
if (
not is_torch_available()
and not is_accelerate_available()
and not is_gguf_available()
):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_gguf_objects
Expand All @@ -84,7 +91,11 @@
_import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig")

try:
if not is_torch_available() and not is_accelerate_available() and not is_torchao_available():
if (
not is_torch_available()
and not is_accelerate_available()
and not is_torchao_available()
):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_torchao_objects
Expand All @@ -96,7 +107,11 @@
_import_structure["quantizers.quantization_config"].append("TorchAoConfig")

try:
if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available():
if (
not is_torch_available()
and not is_accelerate_available()
and not is_optimum_quanto_available()
):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_optimum_quanto_objects
Expand Down Expand Up @@ -126,7 +141,9 @@
except OptionalDependencyNotAvailable:
from .utils import dummy_pt_objects # noqa F403

_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
_import_structure["utils.dummy_pt_objects"] = [
name for name in dir(dummy_pt_objects) if not name.startswith("_")
]

else:
_import_structure["hooks"].extend(
Expand Down Expand Up @@ -187,6 +204,7 @@
"OmniGenTransformer2DModel",
"PixArtTransformer2DModel",
"PriorTransformer",
"SanaControlNetModel",
"SanaTransformer2DModel",
"SD3ControlNetModel",
"SD3MultiControlNetModel",
Expand Down Expand Up @@ -303,11 +321,15 @@
from .utils import dummy_torch_and_torchsde_objects # noqa F403

_import_structure["utils.dummy_torch_and_torchsde_objects"] = [
name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_")
name
for name in dir(dummy_torch_and_torchsde_objects)
if not name.startswith("_")
]

else:
_import_structure["schedulers"].extend(["CosineDPMSolverMultistepScheduler", "DPMSolverSDEScheduler"])
_import_structure["schedulers"].extend(
["CosineDPMSolverMultistepScheduler", "DPMSolverSDEScheduler"]
)

try:
if not (is_torch_available() and is_transformers_available()):
Expand All @@ -316,7 +338,9 @@
from .utils import dummy_torch_and_transformers_objects # noqa F403

_import_structure["utils.dummy_torch_and_transformers_objects"] = [
name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_")
name
for name in dir(dummy_torch_and_transformers_objects)
if not name.startswith("_")
]

else:
Expand Down Expand Up @@ -424,6 +448,7 @@
"PixArtSigmaPAGPipeline",
"PixArtSigmaPipeline",
"ReduxImageEncoder",
"SanaControlNetPipeline",
"SanaPAGPipeline",
"SanaPipeline",
"SanaSprintPipeline",
Expand Down Expand Up @@ -517,39 +542,63 @@
)

try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
if not (
is_torch_available()
and is_transformers_available()
and is_k_diffusion_available()
):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403

_import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [
name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_")
name
for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects)
if not name.startswith("_")
]

else:
_import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline"])
_import_structure["pipelines"].extend(
["StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline"]
)

try:
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
if not (
is_torch_available()
and is_transformers_available()
and is_sentencepiece_available()
):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_torch_and_transformers_and_sentencepiece_objects # noqa F403
from .utils import ( # noqa F403
dummy_torch_and_transformers_and_sentencepiece_objects,
)

_import_structure["utils.dummy_torch_and_transformers_and_sentencepiece_objects"] = [
name for name in dir(dummy_torch_and_transformers_and_sentencepiece_objects) if not name.startswith("_")
_import_structure[
"utils.dummy_torch_and_transformers_and_sentencepiece_objects"
] = [
name
for name in dir(dummy_torch_and_transformers_and_sentencepiece_objects)
if not name.startswith("_")
]

else:
_import_structure["pipelines"].extend(["KolorsImg2ImgPipeline", "KolorsPAGPipeline", "KolorsPipeline"])
_import_structure["pipelines"].extend(
["KolorsImg2ImgPipeline", "KolorsPAGPipeline", "KolorsPipeline"]
)

try:
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
if not (
is_torch_available() and is_transformers_available() and is_onnx_available()
):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403

_import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [
name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_")
name
for name in dir(dummy_torch_and_transformers_and_onnx_objects)
if not name.startswith("_")
]

else:
Expand All @@ -571,20 +620,26 @@
from .utils import dummy_torch_and_librosa_objects # noqa F403

_import_structure["utils.dummy_torch_and_librosa_objects"] = [
name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_")
name
for name in dir(dummy_torch_and_librosa_objects)
if not name.startswith("_")
]

else:
_import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"])

try:
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
if not (
is_transformers_available() and is_torch_available() and is_note_seq_available()
):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403

_import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [
name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_")
name
for name in dir(dummy_transformers_and_torch_and_note_seq_objects)
if not name.startswith("_")
]


Expand All @@ -605,7 +660,9 @@
else:
_import_structure["models.controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
_import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["models.unets.unet_2d_condition_flax"] = [
"FlaxUNet2DConditionModel"
]
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
_import_structure["schedulers"].extend(
Expand All @@ -630,7 +687,9 @@
from .utils import dummy_flax_and_transformers_objects # noqa F403

_import_structure["utils.dummy_flax_and_transformers_objects"] = [
name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_")
name
for name in dir(dummy_flax_and_transformers_objects)
if not name.startswith("_")
]


Expand Down Expand Up @@ -763,6 +822,7 @@
OmniGenTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
SanaControlNetModel,
SanaTransformer2DModel,
SD3ControlNetModel,
SD3MultiControlNetModel,
Expand Down Expand Up @@ -979,6 +1039,7 @@
PixArtSigmaPAGPipeline,
PixArtSigmaPipeline,
ReduxImageEncoder,
SanaControlNetPipeline,
SanaPAGPipeline,
SanaPipeline,
SanaSprintPipeline,
Expand Down Expand Up @@ -1070,22 +1131,35 @@
)

try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
if not (
is_torch_available()
and is_transformers_available()
and is_k_diffusion_available()
):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
else:
from .pipelines import StableDiffusionKDiffusionPipeline, StableDiffusionXLKDiffusionPipeline
from .pipelines import (
StableDiffusionKDiffusionPipeline,
StableDiffusionXLKDiffusionPipeline,
)

try:
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
if not (
is_torch_available()
and is_transformers_available()
and is_sentencepiece_available()
):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_and_sentencepiece_objects import * # noqa F403
else:
from .pipelines import KolorsImg2ImgPipeline, KolorsPAGPipeline, KolorsPipeline
try:
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
if not (
is_torch_available() and is_transformers_available() and is_onnx_available()
):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
Expand All @@ -1108,7 +1182,11 @@
from .pipelines import AudioDiffusionPipeline, Mel

try:
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
if not (
is_transformers_available()
and is_torch_available()
and is_note_seq_available()
):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
Expand Down
Loading